diffusers 源码解析(十一)
# 版权所有 2024 HunyuanDiT 作者,Qixun Wang 和 HuggingFace 团队。保留所有权利。
#
# 根据 Apache 许可证第 2.0 版("许可证")进行许可;
# 除非符合许可证,否则您不得使用此文件。
# 您可以在以下网址获取许可证副本:
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# 除非适用法律或书面同意,否则根据许可证分发的软件均按 "原样" 基础提供,
# 不提供任何种类的保证或条件,无论是明示或暗示的。
# 有关许可证的具体条款和条件,请参阅许可证。
from typing import Dict, Optional, Union # 导入字典、可选和联合类型定义
import torch # 导入 PyTorch 库
from torch import nn # 从 PyTorch 导入神经网络模块
from ...configuration_utils import ConfigMixin, register_to_config # 从配置工具导入混合类和注册功能
from ...utils import logging # 从工具包导入日志记录功能
from ...utils.torch_utils import maybe_allow_in_graph # 导入可能允许图形内操作的功能
from ..attention import FeedForward # 从注意力模块导入前馈网络
from ..attention_processor import Attention, AttentionProcessor, FusedHunyuanAttnProcessor2_0, HunyuanAttnProcessor2_0 # 导入注意力处理器
from ..embeddings import ( # 导入嵌入模块
HunyuanCombinedTimestepTextSizeStyleEmbedding, # 组合时间步、文本、大小和样式的嵌入
PatchEmbed, # 图像补丁嵌入
PixArtAlphaTextProjection, # 像素艺术文本投影
)
from ..modeling_outputs import Transformer2DModelOutput # 导入 2D 变换器模型输出类型
from ..modeling_utils import ModelMixin # 导入模型混合类
from ..normalization import AdaLayerNormContinuous, FP32LayerNorm # 导入自适应层归一化和 FP32 层归一化
logger = logging.get_logger(__name__) # 创建当前模块的日志记录器,禁用 pylint 警告
class AdaLayerNormShift(nn.Module): # 定义自适应层归一化偏移类,继承自 nn.Module
r""" # 类文档字符串,描述类的功能
Norm layer modified to incorporate timestep embeddings. # 归一化层,修改以包含时间步嵌入
Parameters: # 参数说明
embedding_dim (`int`): The size of each embedding vector. # 嵌入向量的大小
num_embeddings (`int`): The size of the embeddings dictionary. # 嵌入字典的大小
"""
def __init__(self, embedding_dim: int, elementwise_affine=True, eps=1e-6): # 初始化方法
super().__init__() # 调用父类初始化方法
self.silu = nn.SiLU() # 定义 SiLU 激活函数
self.linear = nn.Linear(embedding_dim, embedding_dim) # 定义线性层,输入输出维度均为嵌入维度
self.norm = FP32LayerNorm(embedding_dim, elementwise_affine=elementwise_affine, eps=eps) # 定义层归一化
def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor: # 定义前向传播方法
shift = self.linear(self.silu(emb.to(torch.float32)).to(emb.dtype)) # 计算偏移量
x = self.norm(x) + shift.unsqueeze(dim=1) # 对输入进行归一化并加上偏移
return x # 返回处理后的张量
@maybe_allow_in_graph # 装饰器,可能允许在计算图中使用
class HunyuanDiTBlock(nn.Module): # 定义 Hunyuan-DiT 模型中的变换器块类
r""" # 类文档字符串,描述类的功能
Transformer block used in Hunyuan-DiT model (https://github.com/Tencent/HunyuanDiT). Allow skip connection and # Hunyuan-DiT 模型中的变换器块,允许跳过连接和
QKNorm # QKNorm 功能
# 参数说明部分,定义各参数的类型和作用
Parameters:
dim (`int`): # 输入和输出的通道数
The number of channels in the input and output.
num_attention_heads (`int`): # 多头注意力机制中使用的头数
The number of heads to use for multi-head attention.
cross_attention_dim (`int`, *optional*): # 跨注意力的编码器隐藏状态向量的大小
The size of the encoder_hidden_states vector for cross attention.
dropout (`float`, *optional*, defaults to 0.0): # 用于正则化的丢弃概率
The dropout probability to use.
activation_fn (`str`, *optional*, defaults to `"geglu"`): # 前馈网络中使用的激活函数
Activation function to be used in feed-forward.
norm_elementwise_affine (`bool`, *optional*, defaults to `True`): # 是否使用可学习的元素逐个仿射参数进行归一化
Whether to use learnable elementwise affine parameters for normalization.
norm_eps (`float`, *optional*, defaults to 1e-6): # 加到归一化层分母的小常数,以防止除以零
A small constant added to the denominator in normalization layers to prevent division by zero.
final_dropout (`bool`, *optional*, defaults to False): # 在最后的前馈层后是否应用最终丢弃
Whether to apply a final dropout after the last feed-forward layer.
ff_inner_dim (`int`, *optional*): # 前馈块中隐藏层的大小,默认为 None
The size of the hidden layer in the feed-forward block. Defaults to `None`.
ff_bias (`bool`, *optional*, defaults to `True`): # 前馈块中是否使用偏置
Whether to use bias in the feed-forward block.
skip (`bool`, *optional*, defaults to `False`): # 是否使用跳过连接,默认为下块和中块的 False
Whether to use skip connection. Defaults to `False` for down-blocks and mid-blocks.
qk_norm (`bool`, *optional*, defaults to `True`): # 在 QK 计算中是否使用归一化,默认为 True
Whether to use normalization in QK calculation. Defaults to `True`.
"""
# 构造函数的定义,初始化各参数
def __init__(
self,
dim: int, # 输入和输出的通道数
num_attention_heads: int, # 多头注意力机制中使用的头数
cross_attention_dim: int = 1024, # 默认的跨注意力维度
dropout=0.0, # 默认的丢弃概率
activation_fn: str = "geglu", # 默认的激活函数
norm_elementwise_affine: bool = True, # 默认使用可学习的仿射参数
norm_eps: float = 1e-6, # 默认的归一化小常数
final_dropout: bool = False, # 默认不应用最终丢弃
ff_inner_dim: Optional[int] = None, # 默认的前馈块隐藏层大小
ff_bias: bool = True, # 默认使用偏置
skip: bool = False, # 默认不使用跳过连接
qk_norm: bool = True, # 默认在 QK 计算中使用归一化
):
# 调用父类构造函数
super().__init__()
# 定义三个块,每个块都有自己的归一化层。
# 注意:新版本发布时,检查 norm2 和 norm3
# 1. 自注意力机制
self.norm1 = AdaLayerNormShift(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
# 创建自注意力机制的实例
self.attn1 = Attention(
query_dim=dim, # 查询向量的维度
cross_attention_dim=None, # 交叉注意力的维度,未使用
dim_head=dim // num_attention_heads, # 每个头的维度
heads=num_attention_heads, # 注意力头的数量
qk_norm="layer_norm" if qk_norm else None, # 查询和键的归一化方法
eps=1e-6, # 数值稳定性常数
bias=True, # 是否使用偏置
processor=HunyuanAttnProcessor2_0(), # 注意力处理器的实例
)
# 2. 交叉注意力机制
self.norm2 = FP32LayerNorm(dim, norm_eps, norm_elementwise_affine)
# 创建交叉注意力机制的实例
self.attn2 = Attention(
query_dim=dim, # 查询向量的维度
cross_attention_dim=cross_attention_dim, # 交叉注意力的维度
dim_head=dim // num_attention_heads, # 每个头的维度
heads=num_attention_heads, # 注意力头的数量
qk_norm="layer_norm" if qk_norm else None, # 查询和键的归一化方法
eps=1e-6, # 数值稳定性常数
bias=True, # 是否使用偏置
processor=HunyuanAttnProcessor2_0(), # 注意力处理器的实例
)
# 3. 前馈网络
self.norm3 = FP32LayerNorm(dim, norm_eps, norm_elementwise_affine)
# 创建前馈网络的实例
self.ff = FeedForward(
dim, # 输入维度
dropout=dropout, # dropout 比例
activation_fn=activation_fn, # 激活函数
final_dropout=final_dropout, # 最终 dropout 比例
inner_dim=ff_inner_dim, # 内部维度,通常是 dim 的倍数
bias=ff_bias, # 是否使用偏置
)
# 4. 跳跃连接
if skip: # 如果启用跳跃连接
self.skip_norm = FP32LayerNorm(2 * dim, norm_eps, elementwise_affine=True) # 创建归一化层
self.skip_linear = nn.Linear(2 * dim, dim) # 创建线性层
else: # 如果不启用跳跃连接
self.skip_linear = None # 设置为 None
# 将块大小默认为 None
self._chunk_size = None # 初始化块大小
self._chunk_dim = 0 # 初始化块维度
# 从 diffusers.models.attention.BasicTransformerBlock 复制的设置块前馈方法
def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
# 设置块前馈
self._chunk_size = chunk_size # 设置块大小
self._chunk_dim = dim # 设置块维度
def forward(
self,
hidden_states: torch.Tensor, # 输入的隐藏状态
encoder_hidden_states: Optional[torch.Tensor] = None, # 编码器的隐藏状态
temb: Optional[torch.Tensor] = None, # 额外的嵌入
image_rotary_emb=None, # 图像旋转嵌入
skip=None, # 跳跃连接标志
) -> torch.Tensor:
# 注意:以下代码块中的计算总是在归一化之后进行。
# 0. 长跳跃连接
# 如果 skip_linear 不为 None,执行跳跃连接
if self.skip_linear is not None:
# 将当前的隐藏状态与跳跃连接的输出在最后一维上拼接
cat = torch.cat([hidden_states, skip], dim=-1)
# 对拼接后的结果进行归一化处理
cat = self.skip_norm(cat)
# 通过线性层处理归一化后的结果,更新隐藏状态
hidden_states = self.skip_linear(cat)
# 1. 自注意力
# 对当前隐藏状态进行归一化,准备进行自注意力计算
norm_hidden_states = self.norm1(hidden_states, temb) ### checked: self.norm1 is correct
# 计算自注意力的输出
attn_output = self.attn1(
norm_hidden_states,
image_rotary_emb=image_rotary_emb,
)
# 将自注意力的输出加到隐藏状态上,形成新的隐藏状态
hidden_states = hidden_states + attn_output
# 2. 交叉注意力
# 将交叉注意力的输出加到当前的隐藏状态上
hidden_states = hidden_states + self.attn2(
self.norm2(hidden_states), # 先进行归一化
encoder_hidden_states=encoder_hidden_states, # 使用编码器的隐藏状态
image_rotary_emb=image_rotary_emb, # 传递旋转嵌入
)
# 前馈网络层 ### TODO: 在状态字典中切换 norm2 和 norm3
# 对当前的隐藏状态进行归一化处理,准备进入前馈网络
mlp_inputs = self.norm3(hidden_states)
# 通过前馈网络处理归一化后的输入,更新隐藏状态
hidden_states = hidden_states + self.ff(mlp_inputs)
# 返回最终的隐藏状态
return hidden_states
# 定义 HunyuanDiT2DModel 类,继承自 ModelMixin 和 ConfigMixin
class HunyuanDiT2DModel(ModelMixin, ConfigMixin):
"""
HunYuanDiT: 基于 Transformer 的扩散模型。
继承 ModelMixin 和 ConfigMixin 以与 diffusers 的采样器 StableDiffusionPipeline 兼容。
参数:
num_attention_heads (`int`, *可选*, 默认为 16):
多头注意力的头数。
attention_head_dim (`int`, *可选*, 默认为 88):
每个头的通道数。
in_channels (`int`, *可选*):
输入和输出的通道数(如果输入为 **连续**,需指定)。
patch_size (`int`, *可选*):
输入的补丁大小。
activation_fn (`str`, *可选*, 默认为 `"geglu"`):
前馈网络中使用的激活函数。
sample_size (`int`, *可选*):
潜在图像的宽度。训练期间固定使用,以学习位置嵌入的数量。
dropout (`float`, *可选*, 默认为 0.0):
使用的 dropout 概率。
cross_attention_dim (`int`, *可选*):
clip 文本嵌入中的维度数量。
hidden_size (`int`, *可选*):
条件嵌入层中隐藏层的大小。
num_layers (`int`, *可选*, 默认为 1):
使用的 Transformer 块的层数。
mlp_ratio (`float`, *可选*, 默认为 4.0):
隐藏层大小与输入大小的比率。
learn_sigma (`bool`, *可选*, 默认为 `True`):
是否预测方差。
cross_attention_dim_t5 (`int`, *可选*):
t5 文本嵌入中的维度数量。
pooled_projection_dim (`int`, *可选*):
池化投影的大小。
text_len (`int`, *可选*):
clip 文本嵌入的长度。
text_len_t5 (`int`, *可选*):
T5 文本嵌入的长度。
use_style_cond_and_image_meta_size (`bool`, *可选*):
是否使用风格条件和图像元数据大小。版本 <=1.1 为 True,版本 >= 1.2 为 False
"""
# 注册到配置中
@register_to_config
def __init__(
# 多头注意力的头数,默认为 16
self,
num_attention_heads: int = 16,
# 每个头的通道数,默认为 88
attention_head_dim: int = 88,
# 输入和输出的通道数,默认为 None
in_channels: Optional[int] = None,
# 输入的补丁大小,默认为 None
patch_size: Optional[int] = None,
# 激活函数,默认为 "gelu-approximate"
activation_fn: str = "gelu-approximate",
# 潜在图像的宽度,默认为 32
sample_size=32,
# 条件嵌入层中隐藏层的大小,默认为 1152
hidden_size=1152,
# 使用的 Transformer 块的层数,默认为 28
num_layers: int = 28,
# 隐藏层大小与输入大小的比率,默认为 4.0
mlp_ratio: float = 4.0,
# 是否预测方差,默认为 True
learn_sigma: bool = True,
# clip 文本嵌入中的维度数量,默认为 1024
cross_attention_dim: int = 1024,
# 正则化类型,默认为 "layer_norm"
norm_type: str = "layer_norm",
# t5 文本嵌入中的维度数量,默认为 2048
cross_attention_dim_t5: int = 2048,
# 池化投影的大小,默认为 1024
pooled_projection_dim: int = 1024,
# clip 文本嵌入的长度,默认为 77
text_len: int = 77,
# T5 文本嵌入的长度,默认为 256
text_len_t5: int = 256,
# 是否使用风格条件和图像元数据大小,默认为 True
use_style_cond_and_image_meta_size: bool = True,
):
# 调用父类的初始化方法
super().__init__()
# 根据是否学习 sigma 决定输出通道数
self.out_channels = in_channels * 2 if learn_sigma else in_channels
# 设置注意力头的数量
self.num_heads = num_attention_heads
# 计算内部维度,等于注意力头数量乘以每个头的维度
self.inner_dim = num_attention_heads * attention_head_dim
# 初始化文本嵌入器,用于将输入特征投影到更高维空间
self.text_embedder = PixArtAlphaTextProjection(
# 输入特征维度
in_features=cross_attention_dim_t5,
# 隐藏层大小为输入特征的四倍
hidden_size=cross_attention_dim_t5 * 4,
# 输出特征维度
out_features=cross_attention_dim,
# 激活函数设置为"siluf_fp32"
act_fn="silu_fp32",
)
# 初始化文本嵌入的填充参数,使用随机正态分布初始化
self.text_embedding_padding = nn.Parameter(
torch.randn(text_len + text_len_t5, cross_attention_dim, dtype=torch.float32)
)
# 初始化位置嵌入,构建图像的补丁嵌入
self.pos_embed = PatchEmbed(
# 补丁的高度
height=sample_size,
# 补丁的宽度
width=sample_size,
# 输入通道数
in_channels=in_channels,
# 嵌入维度
embed_dim=hidden_size,
# 补丁大小
patch_size=patch_size,
# 位置嵌入类型设置为 None
pos_embed_type=None,
)
# 初始化时间和风格嵌入,结合时间步和文本大小
self.time_extra_emb = HunyuanCombinedTimestepTextSizeStyleEmbedding(
# 隐藏层大小
hidden_size,
# 池化投影维度
pooled_projection_dim=pooled_projection_dim,
# 输入序列长度
seq_len=text_len_t5,
# 交叉注意力维度
cross_attention_dim=cross_attention_dim_t5,
# 是否使用风格条件和图像元数据大小
use_style_cond_and_image_meta_size=use_style_cond_and_image_meta_size,
)
# 初始化 HunyuanDiT 块列表
self.blocks = nn.ModuleList(
[
# 为每一层创建 HunyuanDiTBlock
HunyuanDiTBlock(
# 内部维度
dim=self.inner_dim,
# 注意力头数量
num_attention_heads=self.config.num_attention_heads,
# 激活函数
activation_fn=activation_fn,
# 前馈网络内部维度
ff_inner_dim=int(self.inner_dim * mlp_ratio),
# 交叉注意力维度
cross_attention_dim=cross_attention_dim,
# 查询-键归一化开启
qk_norm=True, # 详情见 http://arxiv.org/abs/2302.05442
# 如果当前层数大于层数的一半,则跳过
skip=layer > num_layers // 2,
)
# 遍历层数
for layer in range(num_layers)
]
)
# 初始化输出的自适应层归一化
self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
# 初始化输出的线性层,将内部维度映射到输出通道数
self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)
# 从 diffusers.models.unets.unet_2d_condition.UNet2DConditionModel 中复制的代码,用于融合 QKV 投影,更新为 FusedHunyuanAttnProcessor2_0
def fuse_qkv_projections(self):
"""
启用融合的 QKV 投影。对于自注意力模块,所有投影矩阵(即查询、键、值)都被融合。
对于交叉注意力模块,键和值投影矩阵被融合。
<Tip warning={true}>
该 API 是 🧪 实验性的。
</Tip>
"""
# 初始化原始注意力处理器为 None
self.original_attn_processors = None
# 遍历所有注意力处理器
for _, attn_processor in self.attn_processors.items():
# 检查注意力处理器类名中是否包含 "Added"
if "Added" in str(attn_processor.__class__.__name__):
# 如果包含,则抛出错误,表示不支持融合 QKV 投影
raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
# 保存当前的注意力处理器以备后用
self.original_attn_processors = self.attn_processors
# 遍历当前模块中的所有子模块
for module in self.modules():
# 检查模块是否为 Attention 类型
if isinstance(module, Attention):
# 对于 Attention 模块,启用投影融合
module.fuse_projections(fuse=True)
# 设置融合的注意力处理器
self.set_attn_processor(FusedHunyuanAttnProcessor2_0())
# 从 diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections 复制
def unfuse_qkv_projections(self):
"""
如果已启用,则禁用融合的 QKV 投影。
<Tip warning={true}>
该 API 是 🧪 实验性的。
</Tip>
"""
# 检查是否有原始注意力处理器
if self.original_attn_processors is not None:
# 恢复为原始注意力处理器
self.set_attn_processor(self.original_attn_processors)
@property
# 从 diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors 复制
def attn_processors(self) -> Dict[str, AttentionProcessor]:
r"""
返回:
`dict` 类型的注意力处理器:一个字典,包含模型中使用的所有注意力处理器,按权重名称索引。
"""
# 初始化处理器字典
processors = {}
# 定义递归添加处理器的函数
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
# 检查模块是否具有 get_processor 方法
if hasattr(module, "get_processor"):
# 将处理器添加到字典中
processors[f"{name}.processor"] = module.get_processor()
# 遍历子模块
for sub_name, child in module.named_children():
# 递归调用以添加子模块的处理器
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
return processors
# 遍历当前模块的所有子模块
for name, module in self.named_children():
# 调用递归函数添加处理器
fn_recursive_add_processors(name, module, processors)
# 返回处理器字典
return processors
# 从 diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor 复制
# 定义设置注意力处理器的方法,接收一个注意力处理器或处理器字典
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
r"""
设置用于计算注意力的注意力处理器。
参数:
processor(`dict` of `AttentionProcessor` 或 `AttentionProcessor`):
已实例化的处理器类或将作为处理器设置的处理器类字典
用于**所有** `Attention` 层。
如果 `processor` 是字典,则键需要定义相应的交叉注意力处理器路径。
当设置可训练的注意力处理器时,这强烈建议。
"""
# 计算当前注意力处理器的数量
count = len(self.attn_processors.keys())
# 检查传入的处理器是否为字典,并验证其长度与注意力层数量是否一致
if isinstance(processor, dict) and len(processor) != count:
raise ValueError(
# 抛出错误,提示处理器数量与注意力层数量不匹配
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
)
# 定义递归设置注意力处理器的内部函数
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
# 检查模块是否具有设置处理器的方法
if hasattr(module, "set_processor"):
# 如果处理器不是字典,直接设置处理器
if not isinstance(processor, dict):
module.set_processor(processor)
else:
# 从字典中取出处理器并设置
module.set_processor(processor.pop(f"{name}.processor"))
# 遍历模块的子模块,递归调用设置处理器的方法
for sub_name, child in module.named_children():
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
# 遍历当前对象的所有子模块,调用递归设置处理器的方法
for name, module in self.named_children():
fn_recursive_attn_processor(name, module, processor)
# 定义设置默认注意力处理器的方法
def set_default_attn_processor(self):
"""
禁用自定义注意力处理器,并设置默认的注意力实现。
"""
# 调用设置注意力处理器的方法,使用默认的 HunyuanAttnProcessor2_0
self.set_attn_processor(HunyuanAttnProcessor2_0())
# 定义前向传播的方法,接收多个输入参数
def forward(
self,
hidden_states,
timestep,
encoder_hidden_states=None,
text_embedding_mask=None,
encoder_hidden_states_t5=None,
text_embedding_mask_t5=None,
image_meta_size=None,
style=None,
image_rotary_emb=None,
controlnet_block_samples=None,
return_dict=True,
# 从 diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking 复制的代码
# 定义一个方法以启用前馈层的分块处理,参数为分块大小和维度
def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None:
"""
设置注意力处理器使用 [前馈分块处理](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers)。
参数:
chunk_size (`int`, *可选*):
前馈层的分块大小。如果未指定,将对每个维度为 `dim` 的张量单独运行前馈层。
dim (`int`, *可选*, 默认为 `0`):
前馈计算应分块的维度。选择 dim=0(批处理)或 dim=1(序列长度)。
"""
# 检查维度是否为 0 或 1
if dim not in [0, 1]:
# 抛出值错误,提示维度设置不当
raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}")
# 默认分块大小为 1
chunk_size = chunk_size or 1
# 定义递归函数以设置前馈层的分块处理
def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
# 如果模块具有设置分块前馈的属性,则进行设置
if hasattr(module, "set_chunk_feed_forward"):
module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)
# 遍历模块的所有子模块并递归调用
for child in module.children():
fn_recursive_feed_forward(child, chunk_size, dim)
# 遍历当前对象的所有子模块并应用递归函数
for module in self.children():
fn_recursive_feed_forward(module, chunk_size, dim)
# 从 diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.disable_forward_chunking 复制
def disable_forward_chunking(self):
# 定义递归函数以禁用前馈层的分块处理
def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
# 如果模块具有设置分块前馈的属性,则进行设置
if hasattr(module, "set_chunk_feed_forward"):
module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)
# 遍历模块的所有子模块并递归调用
for child in module.children():
fn_recursive_feed_forward(child, chunk_size, dim)
# 遍历当前对象的所有子模块并应用递归函数,禁用分块处理
for module in self.children():
fn_recursive_feed_forward(module, None, 0)
# 版权声明,表明该代码的版权属于 Latte 团队和 HuggingFace 团队
# Copyright 2024 the Latte Team and The HuggingFace Team. All rights reserved.
#
# 根据 Apache 许可证 2.0 版("许可证")进行授权;
# 除非遵循该许可证,否则您不得使用此文件。
# 您可以在以下网址获取许可证的副本:
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# 除非法律要求或书面同意,否则根据许可证分发的软件是按 "原样" 基础进行的,
# 不提供任何形式的担保或条件,无论是明示的还是暗示的。
# 请参见许可证以了解有关特定语言的权限和限制。
from typing import Optional # 从 typing 模块导入 Optional 类型,用于指示可选参数
import torch # 导入 PyTorch 库
from torch import nn # 从 PyTorch 导入神经网络模块
# 从配置工具导入 ConfigMixin 和注册配置的功能
from ...configuration_utils import ConfigMixin, register_to_config
# 导入与图像嵌入相关的类和函数
from ...models.embeddings import PixArtAlphaTextProjection, get_1d_sincos_pos_embed_from_grid
# 导入基础变换器块的定义
from ..attention import BasicTransformerBlock
# 导入图像块嵌入的定义
from ..embeddings import PatchEmbed
# 导入 Transformer 2D 模型输出的定义
from ..modeling_outputs import Transformer2DModelOutput
# 导入模型混合功能的定义
from ..modeling_utils import ModelMixin
# 导入自适应层归一化的定义
from ..normalization import AdaLayerNormSingle
# 定义一个 3D Transformer 模型类,继承自 ModelMixin 和 ConfigMixin
class LatteTransformer3DModel(ModelMixin, ConfigMixin):
# 设置支持梯度检查点的标志为 True
_supports_gradient_checkpointing = True
"""
一个用于视频类数据的 3D Transformer 模型,相关论文链接:
https://arxiv.org/abs/2401.03048,官方代码地址:
https://github.com/Vchitect/Latte
"""
# 参数说明
Parameters:
# 使用的多头注意力头数,默认为16
num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
# 每个头的通道数,默认为88
attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
# 输入通道数
in_channels (`int`, *optional*):
The number of channels in the input.
# 输出通道数
out_channels (`int`, *optional*):
The number of channels in the output.
# Transformer块的层数,默认为1
num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
# dropout概率,默认为0.0
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
# 用于cross attention的编码器隐藏状态维度
cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
# 配置TransformerBlocks的注意力是否包含偏置参数
attention_bias (`bool`, *optional*):
Configure if the `TransformerBlocks` attention should contain a bias parameter.
# 潜在图像的宽度(如果输入为离散类型,需指定)
sample_size (`int`, *optional*):
The width of the latent images (specify if the input is **discrete**).
# 在训练期间固定,用于学习位置嵌入数量。
patch_size (`int`, *optional*):
# 在补丁嵌入层中使用的补丁大小。
The size of the patches to use in the patch embedding layer.
# 前馈中的激活函数,默认为"geglu"
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward.
# 训练期间使用的扩散步骤数
num_embeds_ada_norm ( `int`, *optional*):
The number of diffusion steps used during training. Pass if at least one of the norm_layers is
`AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are
added to the hidden states. During inference, you can denoise for up to but not more steps than
`num_embeds_ada_norm`.
# 使用的归一化类型,选项为"layer_norm"或"ada_layer_norm"
norm_type (`str`, *optional*, defaults to `"layer_norm"`):
The type of normalization to use. Options are `"layer_norm"` or `"ada_layer_norm"`.
# 是否在归一化层中使用逐元素仿射,默认为True
norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
Whether or not to use elementwise affine in normalization layers.
# 归一化层中使用的epsilon值,默认为1e-5
norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon value to use in normalization layers.
# 标注嵌入的通道数
caption_channels (`int`, *optional*):
The number of channels in the caption embeddings.
# 视频类数据中的帧数
video_length (`int`, *optional*):
The number of frames in the video-like data.
"""
# 注册配置的装饰器
@register_to_config
# 初始化方法,用于设置模型的参数
def __init__(
# 注意力头的数量,默认为16
num_attention_heads: int = 16,
# 每个注意力头的维度,默认为88
attention_head_dim: int = 88,
# 输入通道数,默认为None
in_channels: Optional[int] = None,
# 输出通道数,默认为None
out_channels: Optional[int] = None,
# 网络层数,默认为1
num_layers: int = 1,
# dropout比率,默认为0.0
dropout: float = 0.0,
# 跨注意力的维度,默认为None
cross_attention_dim: Optional[int] = None,
# 是否使用注意力偏置,默认为False
attention_bias: bool = False,
# 样本大小,默认为64
sample_size: int = 64,
# 每个patch的大小,默认为None
patch_size: Optional[int] = None,
# 激活函数,默认为"geglu"
activation_fn: str = "geglu",
# 自适应归一化的嵌入数量,默认为None
num_embeds_ada_norm: Optional[int] = None,
# 归一化类型,默认为"layer_norm"
norm_type: str = "layer_norm",
# 归一化是否进行逐元素仿射,默认为True
norm_elementwise_affine: bool = True,
# 归一化的epsilon值,默认为1e-5
norm_eps: float = 1e-5,
# caption的通道数,默认为None
caption_channels: int = None,
# 视频长度,默认为16
video_length: int = 16,
# 设置梯度检查点的函数,接收一个模块和一个布尔值
def _set_gradient_checkpointing(self, module, value=False):
# 将梯度检查点设置为给定的布尔值
self.gradient_checkpointing = value
# 前向传播方法,定义模型的前向计算
def forward(
# 输入的隐藏状态,类型为torch.Tensor
hidden_states: torch.Tensor,
# 可选的时间步长,类型为torch.LongTensor
timestep: Optional[torch.LongTensor] = None,
# 可选的编码器隐藏状态,类型为torch.Tensor
encoder_hidden_states: Optional[torch.Tensor] = None,
# 可选的编码器注意力掩码,类型为torch.Tensor
encoder_attention_mask: Optional[torch.Tensor] = None,
# 是否启用时间注意力,默认为True
enable_temporal_attentions: bool = True,
# 是否返回字典形式的输出,默认为True
return_dict: bool = True,
# 版权声明,表明此代码的版权归 2024 Alpha-VLLM 作者及 HuggingFace 团队所有
#
# 根据 Apache 2.0 许可证("许可证")进行许可;
# 除非遵循许可证,否则您不得使用此文件。
# 您可以在以下网址获取许可证副本:
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# 除非适用法律或书面协议另有规定,软件在许可证下分发是按"原样"基础,
# 不提供任何形式的明示或暗示的担保或条件。
# 请参阅许可证以获取有关权限和
# 限制的具体信息。
from typing import Any, Dict, Optional # 导入类型提示相关的模块
import torch # 导入 PyTorch 库
import torch.nn as nn # 导入 PyTorch 的神经网络模块
# 从配置和工具模块导入必要的类和函数
from ...configuration_utils import ConfigMixin, register_to_config
from ...utils import logging
from ..attention import LuminaFeedForward # 导入自定义的前馈网络
from ..attention_processor import Attention, LuminaAttnProcessor2_0 # 导入注意力处理器
from ..embeddings import (
LuminaCombinedTimestepCaptionEmbedding, # 导入组合时间步长的嵌入
LuminaPatchEmbed, # 导入补丁嵌入
)
from ..modeling_outputs import Transformer2DModelOutput # 导入模型输出类
from ..modeling_utils import ModelMixin # 导入模型混合类
from ..normalization import LuminaLayerNormContinuous, LuminaRMSNormZero, RMSNorm # 导入不同的归一化方法
logger = logging.get_logger(__name__) # 获取当前模块的日志记录器,禁用 pylint 的名称警告
class LuminaNextDiTBlock(nn.Module): # 定义一个名为 LuminaNextDiTBlock 的类,继承自 nn.Module
"""
LuminaNextDiTBlock 用于 LuminaNextDiT2DModel。
参数:
dim (`int`): 输入特征的嵌入维度。
num_attention_heads (`int`): 注意力头的数量。
num_kv_heads (`int`):
键和值特征中的注意力头数量(如果使用 GQA),
或设置为 None 以与查询相同。
multiple_of (`int`): 前馈网络层的倍数。
ffn_dim_multiplier (`float`): 前馈网络层维度的乘数因子。
norm_eps (`float`): 归一化层的 epsilon 值。
qk_norm (`bool`): 查询和键的归一化。
cross_attention_dim (`int`): 输入文本提示的跨注意力嵌入维度。
norm_elementwise_affine (`bool`, *可选*, 默认为 True),
"""
def __init__( # 初始化方法
self,
dim: int, # 输入特征的维度
num_attention_heads: int, # 注意力头的数量
num_kv_heads: int, # 键和值特征的头数量
multiple_of: int, # 前馈网络层的倍数
ffn_dim_multiplier: float, # 前馈网络维度的乘数
norm_eps: float, # 归一化的 epsilon 值
qk_norm: bool, # 是否对查询和键进行归一化
cross_attention_dim: int, # 跨注意力嵌入的维度
norm_elementwise_affine: bool = True, # 是否使用逐元素仿射归一化,默认值为 True
) -> None: # 定义方法的返回类型为 None,表示不返回任何值
super().__init__() # 调用父类的构造函数,初始化父类的属性
self.head_dim = dim // num_attention_heads # 计算每个注意力头的维度
self.gate = nn.Parameter(torch.zeros([num_attention_heads])) # 创建一个可学习的参数,初始化为零,大小为注意力头的数量
# Self-attention # 定义自注意力机制
self.attn1 = Attention( # 创建第一个注意力层
query_dim=dim, # 查询的维度
cross_attention_dim=None, # 交叉注意力的维度,此处为 None 表示不使用
dim_head=dim // num_attention_heads, # 每个头的维度
qk_norm="layer_norm_across_heads" if qk_norm else None, # 如果 qk_norm 为真,使用跨头层归一化
heads=num_attention_heads, # 注意力头的数量
kv_heads=num_kv_heads, # 键值对的头数量
eps=1e-5, # 数值稳定性参数
bias=False, # 不使用偏置项
out_bias=False, # 输出层不使用偏置项
processor=LuminaAttnProcessor2_0(), # 使用指定的注意力处理器
)
self.attn1.to_out = nn.Identity() # 输出层使用恒等映射
# Cross-attention # 定义交叉注意力机制
self.attn2 = Attention( # 创建第二个注意力层
query_dim=dim, # 查询的维度
cross_attention_dim=cross_attention_dim, # 交叉注意力的维度
dim_head=dim // num_attention_heads, # 每个头的维度
qk_norm="layer_norm_across_heads" if qk_norm else None, # 如果 qk_norm 为真,使用跨头层归一化
heads=num_attention_heads, # 注意力头的数量
kv_heads=num_kv_heads, # 键值对的头数量
eps=1e-5, # 数值稳定性参数
bias=False, # 不使用偏置项
out_bias=False, # 输出层不使用偏置项
processor=LuminaAttnProcessor2_0(), # 使用指定的注意力处理器
)
self.feed_forward = LuminaFeedForward( # 创建前馈神经网络层
dim=dim, # 输入维度
inner_dim=4 * dim, # 内部维度,通常为输入维度的四倍
multiple_of=multiple_of, # 确保内部维度是某个数字的倍数
ffn_dim_multiplier=ffn_dim_multiplier, # 前馈网络维度的乘数
)
self.norm1 = LuminaRMSNormZero( # 创建第一个 RMS 归一化层
embedding_dim=dim, # 归一化的嵌入维度
norm_eps=norm_eps, # 归一化的 epsilon 参数
norm_elementwise_affine=norm_elementwise_affine, # 是否使用元素级仿射变换
)
self.ffn_norm1 = RMSNorm(dim, eps=norm_eps, elementwise_affine=norm_elementwise_affine) # 创建前馈网络的 RMS 归一化层
self.norm2 = RMSNorm(dim, eps=norm_eps, elementwise_affine=norm_elementwise_affine) # 创建第二个 RMS 归一化层
self.ffn_norm2 = RMSNorm(dim, eps=norm_eps, elementwise_affine=norm_elementwise_affine) # 创建前馈网络的第二个 RMS 归一化层
self.norm1_context = RMSNorm(cross_attention_dim, eps=norm_eps, elementwise_affine=norm_elementwise_affine) # 创建上下文的 RMS 归一化层
def forward( # 定义前向传播方法
self,
hidden_states: torch.Tensor, # 输入的隐藏状态张量
attention_mask: torch.Tensor, # 注意力掩码张量
image_rotary_emb: torch.Tensor, # 图像旋转嵌入张量
encoder_hidden_states: torch.Tensor, # 编码器的隐藏状态张量
encoder_mask: torch.Tensor, # 编码器的掩码张量
temb: torch.Tensor, # 位置编码或时间编码张量
cross_attention_kwargs: Optional[Dict[str, Any]] = None, # 可选的交叉注意力参数字典
):
"""
执行 LuminaNextDiTBlock 的前向传递。
参数:
hidden_states (`torch.Tensor`): LuminaNextDiTBlock 的输入隐藏状态。
attention_mask (`torch.Tensor): 对应隐藏状态的注意力掩码。
image_rotary_emb (`torch.Tensor`): 预计算的余弦和正弦频率。
encoder_hidden_states: (`torch.Tensor`): 通过 Gemma 编码器处理的文本提示的隐藏状态。
encoder_mask (`torch.Tensor`): 文本提示的隐藏状态注意力掩码。
temb (`torch.Tensor`): 带有文本提示嵌入的时间步嵌入。
cross_attention_kwargs (`Dict[str, Any]`): 交叉注意力的参数。
"""
# 保存输入的隐藏状态,以便后续使用
residual = hidden_states
# 自注意力
# 对隐藏状态进行归一化,并计算门控机制的输出
norm_hidden_states, gate_msa, scale_mlp, gate_mlp = self.norm1(hidden_states, temb)
# 计算自注意力的输出
self_attn_output = self.attn1(
hidden_states=norm_hidden_states,
encoder_hidden_states=norm_hidden_states,
attention_mask=attention_mask,
query_rotary_emb=image_rotary_emb,
key_rotary_emb=image_rotary_emb,
**cross_attention_kwargs,
)
# 交叉注意力
# 对编码器的隐藏状态进行归一化
norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states)
# 计算交叉注意力的输出
cross_attn_output = self.attn2(
hidden_states=norm_hidden_states,
encoder_hidden_states=norm_encoder_hidden_states,
attention_mask=encoder_mask,
query_rotary_emb=image_rotary_emb,
key_rotary_emb=None,
**cross_attention_kwargs,
)
# 将交叉注意力的输出进行缩放
cross_attn_output = cross_attn_output * self.gate.tanh().view(1, 1, -1, 1)
# 将自注意力和交叉注意力的输出混合
mixed_attn_output = self_attn_output + cross_attn_output
# 将混合输出展平,以便后续处理
mixed_attn_output = mixed_attn_output.flatten(-2)
# 线性投影
# 通过线性层处理混合输出,得到新的隐藏状态
hidden_states = self.attn2.to_out[0](mixed_attn_output)
# 更新隐藏状态,加入残差连接和门控机制
hidden_states = residual + gate_msa.unsqueeze(1).tanh() * self.norm2(hidden_states)
# 通过前馈网络计算输出
mlp_output = self.feed_forward(self.ffn_norm1(hidden_states) * (1 + scale_mlp.unsqueeze(1)))
# 更新隐藏状态,加入前馈网络输出和门控机制
hidden_states = hidden_states + gate_mlp.unsqueeze(1).tanh() * self.ffn_norm2(mlp_output)
# 返回最终的隐藏状态
return hidden_states
# 定义一个名为 LuminaNextDiT2DModel 的类,继承自 ModelMixin 和 ConfigMixin
class LuminaNextDiT2DModel(ModelMixin, ConfigMixin):
"""
LuminaNextDiT: 使用 Transformer 主干的扩散模型。
继承 ModelMixin 和 ConfigMixin 以兼容 diffusers 的 StableDiffusionPipeline 采样器。
参数:
sample_size (`int`): 潜在图像的宽度。此值在训练期间固定,因为
它用于学习位置嵌入的数量。
patch_size (`int`, *optional*, (`int`, *optional*, defaults to 2):
图像中每个补丁的大小。此参数定义输入到模型中的补丁的分辨率。
in_channels (`int`, *optional*, defaults to 4):
模型的输入通道数量。通常,这与输入图像的通道数量匹配。
hidden_size (`int`, *optional*, defaults to 4096):
模型隐藏层的维度。此参数决定了模型隐藏表示的宽度。
num_layers (`int`, *optional*, default to 32):
模型中的层数。此值定义了神经网络的深度。
num_attention_heads (`int`, *optional*, defaults to 32):
每个注意力层中的注意力头数量。此参数指定使用多少个独立的注意力机制。
num_kv_heads (`int`, *optional*, defaults to 8):
注意力机制中的键值头数量,如果与注意力头数量不同。如果为 None,则默认值为 num_attention_heads。
multiple_of (`int`, *optional*, defaults to 256):
隐藏大小应该是一个倍数的因子。这可以帮助优化某些硬件
配置。
ffn_dim_multiplier (`float`, *optional*):
前馈网络维度的乘数。如果为 None,则使用基于
模型配置的默认值。
norm_eps (`float`, *optional*, defaults to 1e-5):
添加到归一化层的分母中的一个小值,用于数值稳定性。
learn_sigma (`bool`, *optional*, defaults to True):
模型是否应该学习 sigma 参数,该参数可能与预测中的不确定性或方差相关。
qk_norm (`bool`, *optional*, defaults to True):
指示注意力机制中的查询和键是否应该被归一化。
cross_attention_dim (`int`, *optional*, defaults to 2048):
文本嵌入的维度。此参数定义了用于模型的文本表示的大小。
scaling_factor (`float`, *optional*, defaults to 1.0):
应用于模型某些参数或层的缩放因子。此参数可用于调整模型操作的整体规模。
"""
# 注册到配置
@register_to_config
def __init__(
# 样本大小,默认值为128
self,
sample_size: int = 128,
# 补丁大小,默认为2,表示图像切割块的大小
patch_size: Optional[int] = 2,
# 输入通道数,默认为4,表示输入数据的特征通道
in_channels: Optional[int] = 4,
# 隐藏层大小,默认为2304
hidden_size: Optional[int] = 2304,
# 网络层数,默认为32
num_layers: Optional[int] = 32,
# 注意力头数量,默认为32
num_attention_heads: Optional[int] = 32,
# KV头的数量,默认为None
num_kv_heads: Optional[int] = None,
# 数量的倍数,默认为256
multiple_of: Optional[int] = 256,
# FFN维度乘数,默认为None
ffn_dim_multiplier: Optional[float] = None,
# 归一化的epsilon值,默认为1e-5
norm_eps: Optional[float] = 1e-5,
# 是否学习方差,默认为True
learn_sigma: Optional[bool] = True,
# 是否进行QK归一化,默认为True
qk_norm: Optional[bool] = True,
# 交叉注意力维度,默认为2048
cross_attention_dim: Optional[int] = 2048,
# 缩放因子,默认为1.0
scaling_factor: Optional[float] = 1.0,
) -> None:
# 调用父类初始化方法
super().__init__()
# 设置样本大小属性
self.sample_size = sample_size
# 设置补丁大小属性
self.patch_size = patch_size
# 设置输入通道数属性
self.in_channels = in_channels
# 根据是否学习方差设置输出通道数
self.out_channels = in_channels * 2 if learn_sigma else in_channels
# 设置隐藏层大小属性
self.hidden_size = hidden_size
# 设置注意力头数量属性
self.num_attention_heads = num_attention_heads
# 计算并设置每个注意力头的维度
self.head_dim = hidden_size // num_attention_heads
# 设置缩放因子属性
self.scaling_factor = scaling_factor
# 创建补丁嵌入层,并初始化其参数
self.patch_embedder = LuminaPatchEmbed(
patch_size=patch_size, in_channels=in_channels, embed_dim=hidden_size, bias=True
)
# 创建一个可学习的填充标记,初始化为空张量
self.pad_token = nn.Parameter(torch.empty(hidden_size))
# 创建时间和标题的组合嵌入层
self.time_caption_embed = LuminaCombinedTimestepCaptionEmbedding(
hidden_size=min(hidden_size, 1024), cross_attention_dim=cross_attention_dim
)
# 创建包含多个层的模块列表
self.layers = nn.ModuleList(
[
# 在模块列表中添加多个下一代块
LuminaNextDiTBlock(
hidden_size,
num_attention_heads,
num_kv_heads,
multiple_of,
ffn_dim_multiplier,
norm_eps,
qk_norm,
cross_attention_dim,
)
for _ in range(num_layers) # 根据层数循环
]
)
# 创建层归一化输出层
self.norm_out = LuminaLayerNormContinuous(
embedding_dim=hidden_size,
conditioning_embedding_dim=min(hidden_size, 1024),
elementwise_affine=False,
eps=1e-6,
bias=True,
out_dim=patch_size * patch_size * self.out_channels,
)
# 注释掉的最终层的初始化(若需要可取消注释)
# self.final_layer = LuminaFinalLayer(hidden_size, patch_size, self.out_channels)
# 确保隐藏层大小与注意力头数量的关系,保证为4的倍数
assert (hidden_size // num_attention_heads) % 4 == 0, "2d rope needs head dim to be divisible by 4"
# 前向传播函数定义
def forward(
# 隐藏状态的输入张量
self,
hidden_states: torch.Tensor,
# 时间步的输入张量
timestep: torch.Tensor,
# 编码器的隐藏状态张量
encoder_hidden_states: torch.Tensor,
# 编码器的掩码张量
encoder_mask: torch.Tensor,
# 图像的旋转嵌入张量
image_rotary_emb: torch.Tensor,
# 交叉注意力的其他参数,默认为None
cross_attention_kwargs: Dict[str, Any] = None,
# 是否返回字典形式的结果,默认为True
return_dict=True,
# LuminaNextDiT 的前向传播函数
) -> torch.Tensor:
"""
前向传播的 LuminaNextDiT 模型。
参数:
hidden_states (torch.Tensor): 输入张量,形状为 (N, C, H, W)。
timestep (torch.Tensor): 扩散时间步的张量,形状为 (N,).
encoder_hidden_states (torch.Tensor): 描述特征的张量,形状为 (N, D)。
encoder_mask (torch.Tensor): 描述特征掩码的张量,形状为 (N, L)。
"""
# 通过补丁嵌入器处理隐藏状态,获取掩码、图像大小和图像旋转嵌入
hidden_states, mask, img_size, image_rotary_emb = self.patch_embedder(hidden_states, image_rotary_emb)
# 将图像旋转嵌入转移到与隐藏状态相同的设备上
image_rotary_emb = image_rotary_emb.to(hidden_states.device)
# 生成时间嵌入,结合时间步和编码器隐藏状态
temb = self.time_caption_embed(timestep, encoder_hidden_states, encoder_mask)
# 将编码器掩码转换为布尔值
encoder_mask = encoder_mask.bool()
# 对每一层进行遍历,更新隐藏状态
for layer in self.layers:
hidden_states = layer(
hidden_states,
mask,
image_rotary_emb,
encoder_hidden_states,
encoder_mask,
temb=temb,
cross_attention_kwargs=cross_attention_kwargs,
)
# 对隐藏状态进行归一化处理
hidden_states = self.norm_out(hidden_states, temb)
# 反补丁操作
height_tokens = width_tokens = self.patch_size # 获取补丁大小
height, width = img_size[0] # 从图像大小中提取高度和宽度
batch_size = hidden_states.size(0) # 获取批次大小
sequence_length = (height // height_tokens) * (width // width_tokens) # 计算序列长度
# 调整隐藏状态的形状,以适应输出要求
hidden_states = hidden_states[:, :sequence_length].view(
batch_size, height // height_tokens, width // width_tokens, height_tokens, width_tokens, self.out_channels
)
# 调整维度以获得最终输出
output = hidden_states.permute(0, 5, 1, 3, 2, 4).flatten(4, 5).flatten(2, 3)
# 如果不需要返回字典,则返回输出元组
if not return_dict:
return (output,)
# 返回 Transformer2DModelOutput 的结果
return Transformer2DModelOutput(sample=output)
# 版权所有 2024 HuggingFace 团队。保留所有权利。
#
# 根据 Apache 许可证第 2.0 版(“许可证”)进行许可;
# 除非遵守许可证,否则您不得使用此文件。
# 您可以在以下网址获取许可证副本:
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# 除非适用法律或书面协议另有规定,软件
# 按“原样”提供,没有任何形式的明示或暗示的担保或条件。
# 请参阅许可证以了解有关权限和
# 限制的具体条款。
from typing import Any, Dict, Optional, Union # 导入类型提示相关的模块
import torch # 导入 PyTorch 库
from torch import nn # 从 PyTorch 导入神经网络模块
from ...configuration_utils import ConfigMixin, register_to_config # 导入配置相关的混合类和注册函数
from ...utils import is_torch_version, logging # 导入工具函数:检查 PyTorch 版本和日志记录
from ..attention import BasicTransformerBlock # 导入基础 Transformer 块
from ..attention_processor import Attention, AttentionProcessor, FusedAttnProcessor2_0 # 导入注意力相关的处理器
from ..embeddings import PatchEmbed, PixArtAlphaTextProjection # 导入嵌入相关的模块
from ..modeling_outputs import Transformer2DModelOutput # 导入模型输出相关的类
from ..modeling_utils import ModelMixin # 导入模型混合类
from ..normalization import AdaLayerNormSingle # 导入自适应层归一化类
logger = logging.get_logger(__name__) # 获取当前模块的日志记录器;pylint 禁用命名检查
class PixArtTransformer2DModel(ModelMixin, ConfigMixin): # 定义 PixArt 2D Transformer 模型类,继承自 ModelMixin 和 ConfigMixin
r""" # 文档字符串:描述模型及其来源
A 2D Transformer model as introduced in PixArt family of models (https://arxiv.org/abs/2310.00426,
https://arxiv.org/abs/2403.04692).
"""
_supports_gradient_checkpointing = True # 设置支持梯度检查点
_no_split_modules = ["BasicTransformerBlock", "PatchEmbed"] # 指定不进行分割的模块
@register_to_config # 使用装饰器将初始化函数注册到配置中
def __init__( # 定义初始化函数
self,
num_attention_heads: int = 16, # 注意力头的数量,默认为 16
attention_head_dim: int = 72, # 每个注意力头的维度,默认为 72
in_channels: int = 4, # 输入通道数,默认为 4
out_channels: Optional[int] = 8, # 输出通道数,默认为 8,可选
num_layers: int = 28, # 层数,默认为 28
dropout: float = 0.0, # dropout 比例,默认为 0.0
norm_num_groups: int = 32, # 归一化的组数,默认为 32
cross_attention_dim: Optional[int] = 1152, # 交叉注意力的维度,默认为 1152,可选
attention_bias: bool = True, # 是否使用注意力偏置,默认为 True
sample_size: int = 128, # 样本尺寸,默认为 128
patch_size: int = 2, # 每个补丁的尺寸,默认为 2
activation_fn: str = "gelu-approximate", # 激活函数类型,默认为近似 GELU
num_embeds_ada_norm: Optional[int] = 1000, # 自适应归一化的嵌入数量,默认为 1000,可选
upcast_attention: bool = False, # 是否提高注意力精度,默认为 False
norm_type: str = "ada_norm_single", # 归一化类型,默认为单一自适应归一化
norm_elementwise_affine: bool = False, # 是否使用逐元素仿射变换,默认为 False
norm_eps: float = 1e-6, # 归一化的 epsilon 值,默认为 1e-6
interpolation_scale: Optional[int] = None, # 插值尺度,可选
use_additional_conditions: Optional[bool] = None, # 是否使用额外条件,可选
caption_channels: Optional[int] = None, # 说明通道数,可选
attention_type: Optional[str] = "default", # 注意力类型,默认为默认类型
):
# 初始化函数参数设置
...
def _set_gradient_checkpointing(self, module, value=False): # 定义设置梯度检查点的方法
if hasattr(module, "gradient_checkpointing"): # 检查模块是否具有梯度检查点属性
module.gradient_checkpointing = value # 设置梯度检查点的值
@property # 定义一个属性
# 从 diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors 复制的属性
# 定义一个方法,返回模型中所有注意力处理器的字典,键为权重名称
def attn_processors(self) -> Dict[str, AttentionProcessor]:
r"""
Returns:
`dict` of attention processors: A dictionary containing all attention processors used in the model with
indexed by its weight name.
"""
# 创建一个空字典,用于存储注意力处理器
processors = {}
# 定义一个递归函数,用于添加处理器到字典中
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
# 如果模块具有获取处理器的方法,则将其添加到字典中
if hasattr(module, "get_processor"):
processors[f"{name}.processor"] = module.get_processor()
# 遍历子模块,递归调用该函数以添加处理器
for sub_name, child in module.named_children():
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
# 返回更新后的处理器字典
return processors
# 遍历当前模块的所有子模块,调用递归函数以填充处理器字典
for name, module in self.named_children():
fn_recursive_add_processors(name, module, processors)
# 返回所有注意力处理器的字典
return processors
# 从 diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor 复制
# 定义一个方法,用于设置计算注意力的处理器
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
r"""
Sets the attention processor to use to compute attention.
Parameters:
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
The instantiated processor class or a dictionary of processor classes that will be set as the processor
for **all** `Attention` layers.
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
processor. This is strongly recommended when setting trainable attention processors.
"""
# 获取当前注意力处理器的数量
count = len(self.attn_processors.keys())
# 检查传入的处理器字典的长度是否与注意力层的数量匹配
if isinstance(processor, dict) and len(processor) != count:
raise ValueError(
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
)
# 定义一个递归函数,用于设置处理器
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
# 如果模块具有设置处理器的方法,则进行设置
if hasattr(module, "set_processor"):
if not isinstance(processor, dict):
module.set_processor(processor) # 设置单一处理器
else:
module.set_processor(processor.pop(f"{name}.processor")) # 从字典中移除并设置处理器
# 遍历子模块,递归调用以设置处理器
for sub_name, child in module.named_children():
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
# 遍历当前模块的所有子模块,调用递归函数以设置处理器
for name, module in self.named_children():
fn_recursive_attn_processor(name, module, processor)
# 从 diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections 复制
# 定义融合 QKV 投影的函数
def fuse_qkv_projections(self):
# 启用融合的 QKV 投影,对自注意力模块进行融合查询、键、值矩阵
# 对交叉注意力模块则仅融合键和值投影矩阵
"""
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
are fused. For cross-attention modules, key and value projection matrices are fused.
<Tip warning={true}>
This API is 🧪 experimental.
</Tip>
"""
# 初始化原始注意力处理器为 None
self.original_attn_processors = None
# 遍历所有注意力处理器
for _, attn_processor in self.attn_processors.items():
# 检查处理器类名中是否包含 "Added"
if "Added" in str(attn_processor.__class__.__name__):
# 如果存在,则抛出错误,说明不支持此融合操作
raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
# 保存当前的注意力处理器
self.original_attn_processors = self.attn_processors
# 遍历模型中的所有模块
for module in self.modules():
# 如果模块是 Attention 类型
if isinstance(module, Attention):
# 执行投影融合
module.fuse_projections(fuse=True)
# 设置新的融合注意力处理器
self.set_attn_processor(FusedAttnProcessor2_0())
# 从 UNet2DConditionModel 中复制的函数,用于取消融合 QKV 投影
def unfuse_qkv_projections(self):
# 禁用已启用的融合 QKV 投影
"""Disables the fused QKV projection if enabled.
<Tip warning={true}>
This API is 🧪 experimental.
</Tip>
"""
# 检查原始注意力处理器是否存在
if self.original_attn_processors is not None:
# 恢复到原始注意力处理器
self.set_attn_processor(self.original_attn_processors)
# 定义前向传播函数
def forward(
# 输入隐藏状态的张量
hidden_states: torch.Tensor,
# 编码器隐藏状态(可选)
encoder_hidden_states: Optional[torch.Tensor] = None,
# 时间步长(可选)
timestep: Optional[torch.LongTensor] = None,
# 添加的条件关键字参数(字典类型,可选)
added_cond_kwargs: Dict[str, torch.Tensor] = None,
# 交叉注意力关键字参数(字典类型,可选)
cross_attention_kwargs: Dict[str, Any] = None,
# 注意力掩码(可选)
attention_mask: Optional[torch.Tensor] = None,
# 编码器注意力掩码(可选)
encoder_attention_mask: Optional[torch.Tensor] = None,
# 是否返回字典(默认值为 True)
return_dict: bool = True,
# 从 dataclasses 模块导入 dataclass 装饰器
from dataclasses import dataclass
# 导入字典、可选值和联合类型的定义
from typing import Dict, Optional, Union
# 导入 PyTorch 及其功能模块
import torch
import torch.nn.functional as F
# 从 PyTorch 导入神经网络模块
from torch import nn
# 导入配置和注册功能的相关类
from ...configuration_utils import ConfigMixin, register_to_config
# 导入 PeftAdapter 和 UNet2DConditionLoader 的相关类
from ...loaders import PeftAdapterMixin, UNet2DConditionLoadersMixin
# 导入基础输出工具类
from ...utils import BaseOutput
# 导入基本变换器块
from ..attention import BasicTransformerBlock
# 导入注意力处理器的相关组件
from ..attention_processor import (
ADDED_KV_ATTENTION_PROCESSORS,
CROSS_ATTENTION_PROCESSORS,
AttentionProcessor,
AttnAddedKVProcessor,
AttnProcessor,
)
# 导入时间步嵌入和时间步类
from ..embeddings import TimestepEmbedding, Timesteps
# 导入模型混合工具类
from ..modeling_utils import ModelMixin
# 定义 PriorTransformerOutput 数据类,继承自 BaseOutput
@dataclass
class PriorTransformerOutput(BaseOutput):
"""
[`PriorTransformer`] 的输出。
Args:
predicted_image_embedding (`torch.Tensor` 的形状为 `(batch_size, embedding_dim)`):
基于 CLIP 文本嵌入输入的预测 CLIP 图像嵌入。
"""
# 定义预测的图像嵌入属性
predicted_image_embedding: torch.Tensor
# 定义 PriorTransformer 类,继承多个混合类
class PriorTransformer(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, PeftAdapterMixin):
"""
一种 Prior Transformer 模型。
# 参数说明部分
Parameters:
# 用于多头注意力的头数量,默认为 32
num_attention_heads (`int`, *optional*, defaults to 32): The number of heads to use for multi-head attention.
# 每个头的通道数量,默认为 64
attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head.
# Transformer 块的层数,默认为 20
num_layers (`int`, *optional*, defaults to 20): The number of layers of Transformer blocks to use.
# 模型输入 `hidden_states` 的维度,默认为 768
embedding_dim (`int`, *optional*, defaults to 768): The dimension of the model input `hidden_states`
# 模型输入 `hidden_states` 的嵌入数量,默认为 77
num_embeddings (`int`, *optional*, defaults to 77):
The number of embeddings of the model input `hidden_states`
# 附加令牌的数量,默认为 4,追加到投影的 `hidden_states`
additional_embeddings (`int`, *optional*, defaults to 4): The number of additional tokens appended to the
projected `hidden_states`. The actual length of the used `hidden_states` is `num_embeddings +
additional_embeddings`.
# 用于 dropout 的概率,默认为 0.0
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
# 创建时间步嵌入时使用的激活函数,默认为 'silu'
time_embed_act_fn (`str`, *optional*, defaults to 'silu'):
The activation function to use to create timestep embeddings.
# 在传递给 Transformer 块之前应用的归一化层,默认为 None
norm_in_type (`str`, *optional*, defaults to None): The normalization layer to apply on hidden states before
passing to Transformer blocks. Set it to `None` if normalization is not needed.
# 输入 `proj_embedding` 上应用的归一化层,默认为 None
embedding_proj_norm_type (`str`, *optional*, defaults to None):
The normalization layer to apply on the input `proj_embedding`. Set it to `None` if normalization is not
needed.
# 输入 `encoder_hidden_states` 上应用的投影层,默认为 `linear`
encoder_hid_proj_type (`str`, *optional*, defaults to `linear`):
The projection layer to apply on the input `encoder_hidden_states`. Set it to `None` if
`encoder_hidden_states` is `None`.
# 条件模型的附加嵌入类型,默认为 `prd`
added_emb_type (`str`, *optional*, defaults to `prd`): Additional embeddings to condition the model.
Choose from `prd` or `None`. if choose `prd`, it will prepend a token indicating the (quantized) dot
product between the text embedding and image embedding as proposed in the unclip paper
https://arxiv.org/abs/2204.06125 If it is `None`, no additional embeddings will be prepended.
# 时间步嵌入的维度,默认为 None,如果为 None,则设置为 `num_attention_heads * attention_head_dim`
time_embed_dim (`int, *optional*, defaults to None): The dimension of timestep embeddings.
If None, will be set to `num_attention_heads * attention_head_dim`
# `proj_embedding` 的维度,默认为 None,如果为 None,则设置为 `embedding_dim`
embedding_proj_dim (`int`, *optional*, default to None):
The dimension of `proj_embedding`. If None, will be set to `embedding_dim`.
# 输出的维度,默认为 None,如果为 None,则设置为 `embedding_dim`
clip_embed_dim (`int`, *optional*, default to None):
The dimension of the output. If None, will be set to `embedding_dim`.
"""
# 注册到配置中
@register_to_config
# 初始化类的构造函数,设置默认参数
def __init__(
# 注意力头的数量,默认值为32
self,
num_attention_heads: int = 32,
# 每个注意力头的维度,默认值为64
attention_head_dim: int = 64,
# 层的数量,默认值为20
num_layers: int = 20,
# 嵌入的维度,默认值为768
embedding_dim: int = 768,
# 嵌入的数量,默认值为77
num_embeddings=77,
# 额外嵌入的数量,默认值为4
additional_embeddings=4,
# dropout的比率,默认值为0.0
dropout: float = 0.0,
# 时间嵌入激活函数的类型,默认值为"silu"
time_embed_act_fn: str = "silu",
# 输入归一化类型,默认为None
norm_in_type: Optional[str] = None, # layer
# 嵌入投影归一化类型,默认为None
embedding_proj_norm_type: Optional[str] = None, # layer
# 编码器隐藏投影类型,默认值为"linear"
encoder_hid_proj_type: Optional[str] = "linear", # linear
# 添加的嵌入类型,默认值为"prd"
added_emb_type: Optional[str] = "prd", # prd
# 时间嵌入维度,默认为None
time_embed_dim: Optional[int] = None,
# 嵌入投影维度,默认为None
embedding_proj_dim: Optional[int] = None,
# 裁剪嵌入维度,默认为None
clip_embed_dim: Optional[int] = None,
# 定义一个属性,获取注意力处理器
@property
# 从diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors复制而来
def attn_processors(self) -> Dict[str, AttentionProcessor]:
r"""
返回值:
`dict`类型的注意力处理器:一个字典,包含模型中使用的所有注意力处理器,并按其权重名称索引。
"""
# 创建一个空字典用于存储处理器
processors = {}
# 定义递归函数,添加处理器到字典
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
# 如果模块有获取处理器的方法,添加到字典中
if hasattr(module, "get_processor"):
processors[f"{name}.processor"] = module.get_processor()
# 遍历模块的所有子模块,递归调用自身
for sub_name, child in module.named_children():
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
# 返回更新后的处理器字典
return processors
# 遍历当前类的所有子模块,调用递归函数添加处理器
for name, module in self.named_children():
fn_recursive_add_processors(name, module, processors)
# 返回最终的处理器字典
return processors
# 从diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor复制而来
# 设置用于计算注意力的处理器
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
r"""
设置用于计算注意力的处理器。
参数:
processor(`dict` 或 `AttentionProcessor`):
实例化的处理器类或处理器类的字典,将作为 **所有** `Attention` 层的处理器。
如果 `processor` 是一个字典,键需要定义相应的交叉注意力处理器的路径。建议在设置可训练注意力处理器时使用此方法。
"""
# 计算当前注意力处理器的数量
count = len(self.attn_processors.keys())
# 如果传入的是字典且其长度与注意力层数量不匹配,抛出错误
if isinstance(processor, dict) and len(processor) != count:
raise ValueError(
f"传入了处理器的字典,但处理器的数量 {len(processor)} 与注意力层的数量 {count} 不匹配。请确保传入 {count} 个处理器类。"
)
# 定义递归设置注意力处理器的函数
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
# 如果模块具有 set_processor 方法,则设置处理器
if hasattr(module, "set_processor"):
if not isinstance(processor, dict):
module.set_processor(processor)
else:
module.set_processor(processor.pop(f"{name}.processor"))
# 遍历子模块并递归调用
for sub_name, child in module.named_children():
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
# 遍历所有子模块,调用递归函数设置处理器
for name, module in self.named_children():
fn_recursive_attn_processor(name, module, processor)
# 从 diffusers.models.unets.unet_2d_condition.UNet2DConditionModel 复制的设置默认注意力处理器的方法
def set_default_attn_processor(self):
"""
禁用自定义注意力处理器并设置默认的注意力实现。
"""
# 如果所有处理器都是添加的 KV 注意力处理器,则设置为添加的 KV 处理器
if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
processor = AttnAddedKVProcessor()
# 如果所有处理器都是交叉注意力处理器,则设置为普通的注意力处理器
elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
processor = AttnProcessor()
else:
raise ValueError(
f"当注意力处理器的类型为 {next(iter(self.attn_processors.values()))} 时,无法调用 `set_default_attn_processor`"
)
# 调用设置处理器的方法
self.set_attn_processor(processor)
# 前向传播方法定义
def forward(
self,
hidden_states,
timestep: Union[torch.Tensor, float, int],
proj_embedding: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.BoolTensor] = None,
return_dict: bool = True,
):
# 处理传入的潜在变量
def post_process_latents(self, prior_latents):
# 将潜在变量进行标准化处理
prior_latents = (prior_latents * self.clip_std) + self.clip_mean
return prior_latents
# 版权声明,注明版权归属
# Copyright 2024 Stability AI and The HuggingFace Team. All rights reserved.
#
# 根据 Apache License 2.0 授权协议进行许可
# Licensed under the Apache License, Version 2.0 (the "License");
# 只有在遵守许可证的情况下,您才能使用此文件
# you may not use this file except in compliance with the License.
# 您可以在以下网址获取许可证副本
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# 除非法律要求或书面同意,软件按“原样”分发
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# 不提供任何形式的保证或条件,明示或暗示
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# 请参见许可证以获取有关权限和限制的具体信息
# See the License for the specific language governing permissions and
# limitations under the License.
# 导入类型注解
from typing import Any, Dict, Optional, Union
# 导入 NumPy 库
import numpy as np
# 导入 PyTorch 库
import torch
# 导入 PyTorch 的神经网络模块
import torch.nn as nn
# 导入 PyTorch 的检查点工具
import torch.utils.checkpoint
# 从配置工具导入相关类
from ...configuration_utils import ConfigMixin, register_to_config
# 从注意力模块导入前馈网络
from ...models.attention import FeedForward
# 从注意力处理器模块导入多个类
from ...models.attention_processor import (
Attention,
AttentionProcessor,
StableAudioAttnProcessor2_0,
)
# 从建模工具导入模型混合类
from ...models.modeling_utils import ModelMixin
# 从变换器模型导入输出类
from ...models.transformers.transformer_2d import Transformer2DModelOutput
# 导入实用工具
from ...utils import is_torch_version, logging
# 导入可能允许图形中的工具函数
from ...utils.torch_utils import maybe_allow_in_graph
# 获取当前模块的日志记录器
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class StableAudioGaussianFourierProjection(nn.Module):
"""用于噪声级别的高斯傅里叶嵌入。"""
# 从 diffusers.models.embeddings.GaussianFourierProjection.__init__ 复制的内容
def __init__(
self, embedding_size: int = 256, scale: float = 1.0, set_W_to_weight=True, log=True, flip_sin_to_cos=False
):
super().__init__() # 调用父类的构造函数
# 初始化权重为随机值,且不需要计算梯度
self.weight = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False)
self.log = log # 是否对输入取对数的标志
self.flip_sin_to_cos = flip_sin_to_cos # 是否翻转正弦和余弦的顺序
if set_W_to_weight: # 如果设置将 W 赋值给权重
# 之后将删除此行
del self.weight # 删除原有权重
# 初始化 W 为随机值,并不计算梯度
self.W = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False)
self.weight = self.W # 将 W 赋值给权重
del self.W # 删除 W
def forward(self, x):
# 如果 log 为 True,则对输入进行对数变换
if self.log:
x = torch.log(x)
# 计算投影,使用 2π 乘以输入和权重的外积
x_proj = 2 * np.pi * x[:, None] @ self.weight[None, :]
if self.flip_sin_to_cos: # 如果翻转正弦和余弦
# 连接余弦和正弦,形成输出
out = torch.cat([torch.cos(x_proj), torch.sin(x_proj)], dim=-1)
else:
# 连接正弦和余弦,形成输出
out = torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)
return out # 返回输出
@maybe_allow_in_graph # 可能允许在计算图中使用
class StableAudioDiTBlock(nn.Module):
r"""
用于稳定音频模型的变换器块 (https://github.com/Stability-AI/stable-audio-tools)。允许跳跃连接和 QKNorm
# 参数说明
Parameters:
dim (`int`): 输入和输出的通道数。
num_attention_heads (`int`): 查询状态所使用的头数。
num_key_value_attention_heads (`int`): 键和值状态所使用的头数。
attention_head_dim (`int`): 每个头中的通道数。
dropout (`float`, *optional*, defaults to 0.0): 使用的丢弃概率。
cross_attention_dim (`int`, *optional*): 跨注意力的 encoder_hidden_states 向量的大小。
upcast_attention (`bool`, *optional*):
是否将注意力计算上升到 float32。这对混合精度训练很有用。
"""
# 初始化函数
def __init__(
self,
dim: int, # 输入和输出的通道数
num_attention_heads: int, # 查询状态的头数
num_key_value_attention_heads: int, # 键和值状态的头数
attention_head_dim: int, # 每个头的通道数
dropout=0.0, # 丢弃概率,默认为0
cross_attention_dim: Optional[int] = None, # 跨注意力的维度,可选
upcast_attention: bool = False, # 是否上升到 float32,默认为 False
norm_eps: float = 1e-5, # 归一化层的小常数
ff_inner_dim: Optional[int] = None, # 前馈层内部维度,可选
):
super().__init__() # 调用父类构造函数
# 定义三个模块。每个模块都有自己的归一化层。
# 1. 自注意力层
self.norm1 = nn.LayerNorm(dim, elementwise_affine=True, eps=norm_eps) # 自注意力的归一化层
self.attn1 = Attention( # 自注意力模块
query_dim=dim, # 查询维度
heads=num_attention_heads, # 头数
dim_head=attention_head_dim, # 每个头的维度
dropout=dropout, # 丢弃概率
bias=False, # 不使用偏置
upcast_attention=upcast_attention, # 是否上升到 float32
out_bias=False, # 不使用输出偏置
processor=StableAudioAttnProcessor2_0(), # 使用的处理器
)
# 2. 跨注意力层
self.norm2 = nn.LayerNorm(dim, norm_eps, True) # 跨注意力的归一化层
self.attn2 = Attention( # 跨注意力模块
query_dim=dim, # 查询维度
cross_attention_dim=cross_attention_dim, # 跨注意力维度
heads=num_attention_heads, # 头数
dim_head=attention_head_dim, # 每个头的维度
kv_heads=num_key_value_attention_heads, # 键和值的头数
dropout=dropout, # 丢弃概率
bias=False, # 不使用偏置
upcast_attention=upcast_attention, # 是否上升到 float32
out_bias=False, # 不使用输出偏置
processor=StableAudioAttnProcessor2_0(), # 使用的处理器
) # 如果 encoder_hidden_states 为 None,则为自注意力
# 3. 前馈层
self.norm3 = nn.LayerNorm(dim, norm_eps, True) # 前馈层的归一化层
self.ff = FeedForward( # 前馈神经网络模块
dim, # 输入维度
dropout=dropout, # 丢弃概率
activation_fn="swiglu", # 激活函数
final_dropout=False, # 最后是否丢弃
inner_dim=ff_inner_dim, # 内部维度
bias=True, # 使用偏置
)
# 将块大小默认设置为 None
self._chunk_size = None # 块大小
self._chunk_dim = 0 # 块维度
def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
# 设置块前馈
self._chunk_size = chunk_size # 设置块大小
self._chunk_dim = dim # 设置块维度
# 定义前向传播方法,接收隐藏状态和可选的注意力掩码等参数
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
rotary_embedding: Optional[torch.FloatTensor] = None,
) -> torch.Tensor:
# 注意:在后续计算中,归一化总是应用于实际计算之前。
# 0. 自注意力
# 对输入的隐藏状态进行归一化处理
norm_hidden_states = self.norm1(hidden_states)
# 计算自注意力输出
attn_output = self.attn1(
norm_hidden_states,
attention_mask=attention_mask,
rotary_emb=rotary_embedding,
)
# 将自注意力输出与原始隐藏状态相加
hidden_states = attn_output + hidden_states
# 2. 跨注意力
# 对更新后的隐藏状态进行归一化处理
norm_hidden_states = self.norm2(hidden_states)
# 计算跨注意力输出
attn_output = self.attn2(
norm_hidden_states,
encoder_hidden_states=encoder_hidden_states,
attention_mask=encoder_attention_mask,
)
# 将跨注意力输出与更新后的隐藏状态相加
hidden_states = attn_output + hidden_states
# 3. 前馈网络
# 对隐藏状态进行归一化处理
norm_hidden_states = self.norm3(hidden_states)
# 计算前馈网络输出
ff_output = self.ff(norm_hidden_states)
# 将前馈网络输出与当前隐藏状态相加
hidden_states = ff_output + hidden_states
# 返回最终的隐藏状态
return hidden_states
# 定义一个名为 StableAudioDiTModel 的类,继承自 ModelMixin 和 ConfigMixin
class StableAudioDiTModel(ModelMixin, ConfigMixin):
"""
Stable Audio 中引入的扩散变换器模型。
参考文献:https://github.com/Stability-AI/stable-audio-tools
参数:
sample_size ( `int`, *可选*, 默认值为 1024):输入样本的大小。
in_channels (`int`, *可选*, 默认值为 64):输入中的通道数。
num_layers (`int`, *可选*, 默认值为 24):使用的变换器块的层数。
attention_head_dim (`int`, *可选*, 默认值为 64):每个头的通道数。
num_attention_heads (`int`, *可选*, 默认值为 24):用于查询状态的头数。
num_key_value_attention_heads (`int`, *可选*, 默认值为 12):
用于键和值状态的头数。
out_channels (`int`, 默认值为 64):输出通道的数量。
cross_attention_dim ( `int`, *可选*, 默认值为 768):交叉注意力投影的维度。
time_proj_dim ( `int`, *可选*, 默认值为 256):时间步内投影的维度。
global_states_input_dim ( `int`, *可选*, 默认值为 1536):
全局隐藏状态投影的输入维度。
cross_attention_input_dim ( `int`, *可选*, 默认值为 768):
交叉注意力投影的输入维度。
"""
# 支持梯度检查点
_supports_gradient_checkpointing = True
# 注册到配置的构造函数
@register_to_config
def __init__(
# 输入样本的大小,默认为1024
self,
sample_size: int = 1024,
# 输入的通道数,默认为64
in_channels: int = 64,
# 变换器块的层数,默认为24
num_layers: int = 24,
# 每个头的通道数,默认为64
attention_head_dim: int = 64,
# 查询状态的头数,默认为24
num_attention_heads: int = 24,
# 键和值状态的头数,默认为12
num_key_value_attention_heads: int = 12,
# 输出通道的数量,默认为64
out_channels: int = 64,
# 交叉注意力投影的维度,默认为768
cross_attention_dim: int = 768,
# 时间步内投影的维度,默认为256
time_proj_dim: int = 256,
# 全局隐藏状态投影的输入维度,默认为1536
global_states_input_dim: int = 1536,
# 交叉注意力投影的输入维度,默认为768
cross_attention_input_dim: int = 768,
):
# 调用父类的初始化方法
super().__init__()
# 设置样本大小
self.sample_size = sample_size
# 设置输出通道数
self.out_channels = out_channels
# 计算内部维度,等于注意力头数量乘以每个头的维度
self.inner_dim = num_attention_heads * attention_head_dim
# 创建稳定音频高斯傅里叶投影对象,embedding_size 为时间投影维度的一半
self.time_proj = StableAudioGaussianFourierProjection(
embedding_size=time_proj_dim // 2,
flip_sin_to_cos=True, # 是否翻转正弦和余弦
log=False, # 是否使用对数
set_W_to_weight=False, # 是否将 W 设置为权重
)
# 时间步投影的神经网络序列,包含两个线性层和一个激活函数
self.timestep_proj = nn.Sequential(
nn.Linear(time_proj_dim, self.inner_dim, bias=True), # 输入为 time_proj_dim,输出为 inner_dim
nn.SiLU(), # 使用 SiLU 激活函数
nn.Linear(self.inner_dim, self.inner_dim, bias=True), # 再次投影到 inner_dim
)
# 全局状态投影的神经网络序列,包含两个线性层和一个激活函数
self.global_proj = nn.Sequential(
nn.Linear(global_states_input_dim, self.inner_dim, bias=False), # 输入为 global_states_input_dim,输出为 inner_dim
nn.SiLU(), # 使用 SiLU 激活函数
nn.Linear(self.inner_dim, self.inner_dim, bias=False), # 再次投影到 inner_dim
)
# 交叉注意力投影的神经网络序列,包含两个线性层和一个激活函数
self.cross_attention_proj = nn.Sequential(
nn.Linear(cross_attention_input_dim, cross_attention_dim, bias=False), # 输入为 cross_attention_input_dim,输出为 cross_attention_dim
nn.SiLU(), # 使用 SiLU 激活函数
nn.Linear(cross_attention_dim, cross_attention_dim, bias=False), # 再次投影到 cross_attention_dim
)
# 一维卷积层,用于预处理,卷积核大小为 1,不使用偏置
self.preprocess_conv = nn.Conv1d(in_channels, in_channels, 1, bias=False)
# 输入线性层,将输入通道数投影到 inner_dim,不使用偏置
self.proj_in = nn.Linear(in_channels, self.inner_dim, bias=False)
# 创建一个模块列表,包含多个 StableAudioDiTBlock
self.transformer_blocks = nn.ModuleList(
[
StableAudioDiTBlock(
dim=self.inner_dim, # 输入维度为 inner_dim
num_attention_heads=num_attention_heads, # 注意力头数量
num_key_value_attention_heads=num_key_value_attention_heads, # 键值注意力头数量
attention_head_dim=attention_head_dim, # 每个注意力头的维度
cross_attention_dim=cross_attention_dim, # 交叉注意力维度
)
for i in range(num_layers) # 根据层数创建相应数量的块
]
)
# 输出线性层,将 inner_dim 投影到输出通道数,不使用偏置
self.proj_out = nn.Linear(self.inner_dim, self.out_channels, bias=False)
# 一维卷积层,用于后处理,卷积核大小为 1,不使用偏置
self.postprocess_conv = nn.Conv1d(self.out_channels, self.out_channels, 1, bias=False)
# 初始化梯度检查点标志,默认为 False
self.gradient_checkpointing = False
@property
# 从 UNet2DConditionModel 复制的属性
def attn_processors(self) -> Dict[str, AttentionProcessor]:
r"""
Returns:
`dict` of attention processors: A dictionary containing all attention processors used in the model with
indexed by its weight name.
"""
# 创建一个空字典,用于存储注意力处理器
processors = {}
# 定义递归函数,用于添加注意力处理器
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
# 如果模块有 get_processor 方法,获取处理器并添加到字典
if hasattr(module, "get_processor"):
processors[f"{name}.processor"] = module.get_processor()
# 遍历子模块,递归调用添加处理器函数
for sub_name, child in module.named_children():
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
return processors
# 遍历当前对象的所有子模块,调用递归函数
for name, module in self.named_children():
fn_recursive_add_processors(name, module, processors)
# 返回包含所有处理器的字典
return processors
# 从 diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor 复制而来
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
r"""
设置用于计算注意力的注意力处理器。
参数:
processor (`dict` of `AttentionProcessor` 或 `AttentionProcessor`):
实例化的处理器类或将作为处理器设置为 **所有** `Attention` 层的处理器类字典。
如果 `processor` 是一个字典,键需要定义对应的交叉注意力处理器的路径。
当设置可训练的注意力处理器时,这一点强烈推荐。
"""
# 获取当前注意力处理器的数量
count = len(self.attn_processors.keys())
# 如果传入的是字典且字典长度与注意力层数量不匹配,则抛出异常
if isinstance(processor, dict) and len(processor) != count:
raise ValueError(
f"传入了处理器字典,但处理器数量 {len(processor)} 与"
f" 注意力层数量 {count} 不匹配。请确保传入 {count} 个处理器类。"
)
# 定义递归设置注意力处理器的函数
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
# 如果模块有 set_processor 方法,则设置处理器
if hasattr(module, "set_processor"):
if not isinstance(processor, dict):
# 如果处理器不是字典,则直接设置
module.set_processor(processor)
else:
# 从字典中弹出对应的处理器并设置
module.set_processor(processor.pop(f"{name}.processor"))
# 遍历模块的子模块,递归调用设置处理器的函数
for sub_name, child in module.named_children():
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
# 遍历当前对象的子模块,调用递归设置处理器的函数
for name, module in self.named_children():
fn_recursive_attn_processor(name, module, processor)
# 从 diffusers.models.transformers.hunyuan_transformer_2d.HunyuanDiT2DModel.set_default_attn_processor 复制而来,将 Hunyuan 替换为 StableAudio
def set_default_attn_processor(self):
"""
禁用自定义注意力处理器,并设置默认的注意力实现。
"""
# 调用设置注意力处理器的方法,使用 StableAudioAttnProcessor2_0 实例
self.set_attn_processor(StableAudioAttnProcessor2_0())
# 设置梯度检查点的私有方法
def _set_gradient_checkpointing(self, module, value=False):
# 如果模块有 gradient_checkpointing 属性,则设置其值
if hasattr(module, "gradient_checkpointing"):
module.gradient_checkpointing = value
# 前向传播方法定义
def forward(
# 输入的隐藏状态张量
hidden_states: torch.FloatTensor,
# 时间步张量,默认为 None
timestep: torch.LongTensor = None,
# 编码器的隐藏状态张量,默认为 None
encoder_hidden_states: torch.FloatTensor = None,
# 全局隐藏状态张量,默认为 None
global_hidden_states: torch.FloatTensor = None,
# 旋转嵌入张量,默认为 None
rotary_embedding: torch.FloatTensor = None,
# 是否返回字典格式,默认为 True
return_dict: bool = True,
# 注意力掩码,默认为 None
attention_mask: Optional[torch.LongTensor] = None,
# 编码器的注意力掩码,默认为 None
encoder_attention_mask: Optional[torch.LongTensor] = None,
# 版权所有 2024 The HuggingFace Team. 保留所有权利。
#
# 根据 Apache 许可证第 2.0 版(“许可证”)进行许可;
# 除非遵守许可证,否则您不得使用此文件。
# 您可以在以下网址获取许可证副本:
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# 除非适用法律或书面协议另有规定,依据许可证分发的软件
# 是按“原样”基础分发的,没有任何形式的保证或条件,
# 无论是明示还是暗示。有关许可证下特定语言的权限和
# 限制,请参阅许可证。
# 导入数学模块以执行数学运算
import math
# 从 typing 导入可选类型和元组
from typing import Optional, Tuple
# 导入 PyTorch 库
import torch
# 从 torch 导入神经网络模块
from torch import nn
# 导入配置工具和注册功能
from ...configuration_utils import ConfigMixin, register_to_config
# 导入注意力处理器
from ..attention_processor import Attention
# 导入获取时间步嵌入的功能
from ..embeddings import get_timestep_embedding
# 导入模型工具的基类
from ..modeling_utils import ModelMixin
class T5FilmDecoder(ModelMixin, ConfigMixin):
r"""
T5 风格的解码器,具有 FiLM 条件。
参数:
input_dims (`int`, *可选*, 默认为 `128`):
输入维度的数量。
targets_length (`int`, *可选*, 默认为 `256`):
目标的长度。
d_model (`int`, *可选*, 默认为 `768`):
输入隐藏状态的大小。
num_layers (`int`, *可选*, 默认为 `12`):
使用的 `DecoderLayer` 数量。
num_heads (`int`, *可选*, 默认为 `12`):
使用的注意力头的数量。
d_kv (`int`, *可选*, 默认为 `64`):
键值投影向量的大小。
d_ff (`int`, *可选*, 默认为 `2048`):
`DecoderLayer` 中间前馈层的维度数量。
dropout_rate (`float`, *可选*, 默认为 `0.1`):
丢弃概率。
"""
# 使用装饰器注册初始化函数到配置
@register_to_config
def __init__(
# 输入维度,默认为128
self,
input_dims: int = 128,
# 目标长度,默认为256
targets_length: int = 256,
# 最大解码噪声时间,默认为2000.0
max_decoder_noise_time: float = 2000.0,
# 隐藏状态的维度,默认为768
d_model: int = 768,
# 解码层的数量,默认为12
num_layers: int = 12,
# 注意力头的数量,默认为12
num_heads: int = 12,
# 键值维度大小,默认为64
d_kv: int = 64,
# 中间前馈层的维度,默认为2048
d_ff: int = 2048,
# 丢弃率,默认为0.1
dropout_rate: float = 0.1,
# 初始化父类
):
super().__init__()
# 创建条件嵌入层,包含两层线性变换和激活函数
self.conditioning_emb = nn.Sequential(
# 第一个线性层,输入维度为 d_model,输出维度为 d_model * 4,不使用偏置
nn.Linear(d_model, d_model * 4, bias=False),
# 使用 SiLU 激活函数
nn.SiLU(),
# 第二个线性层,输入维度为 d_model * 4,输出维度同样为 d_model * 4,不使用偏置
nn.Linear(d_model * 4, d_model * 4, bias=False),
# 使用 SiLU 激活函数
nn.SiLU(),
)
# 创建位置编码嵌入,大小为 (targets_length, d_model)
self.position_encoding = nn.Embedding(targets_length, d_model)
# 禁止位置编码的权重更新
self.position_encoding.weight.requires_grad = False
# 创建连续输入的线性投影层,输入维度为 input_dims,输出维度为 d_model,不使用偏置
self.continuous_inputs_projection = nn.Linear(input_dims, d_model, bias=False)
# 创建 dropout 层,丢弃率为 dropout_rate
self.dropout = nn.Dropout(p=dropout_rate)
# 创建解码器层的模块列表
self.decoders = nn.ModuleList()
# 循环创建 num_layers 个解码器层
for lyr_num in range(num_layers):
# 初始化 FiLM 条件 T5 解码器层
lyr = DecoderLayer(d_model=d_model, d_kv=d_kv, num_heads=num_heads, d_ff=d_ff, dropout_rate=dropout_rate)
# 将解码器层添加到列表中
self.decoders.append(lyr)
# 创建解码器层的归一化层
self.decoder_norm = T5LayerNorm(d_model)
# 创建后续 dropout 层,丢弃率为 dropout_rate
self.post_dropout = nn.Dropout(p=dropout_rate)
# 创建输出层,将 d_model 的输出映射回 input_dims,不使用偏置
self.spec_out = nn.Linear(d_model, input_dims, bias=False)
# 定义编码器-解码器掩码函数
def encoder_decoder_mask(self, query_input: torch.Tensor, key_input: torch.Tensor) -> torch.Tensor:
# 计算查询和键输入的掩码,进行逐元素相乘
mask = torch.mul(query_input.unsqueeze(-1), key_input.unsqueeze(-2))
# 返回掩码并扩展维度
return mask.unsqueeze(-3)
# 前向传播方法,接受编码及掩码、解码器输入和噪声时间
def forward(self, encodings_and_masks, decoder_input_tokens, decoder_noise_time):
# 获取批次大小和解码器输入的形状
batch, _, _ = decoder_input_tokens.shape
# 确保噪声时间的形状与批次一致
assert decoder_noise_time.shape == (batch,)
# 将 decoder_noise_time 重新缩放到期望的时间范围
time_steps = get_timestep_embedding(
decoder_noise_time * self.config.max_decoder_noise_time,
embedding_dim=self.config.d_model,
max_period=self.config.max_decoder_noise_time,
).to(dtype=self.dtype)
# 使用时间步长生成条件嵌入,并扩展维度
conditioning_emb = self.conditioning_emb(time_steps).unsqueeze(1)
# 确保条件嵌入的形状正确
assert conditioning_emb.shape == (batch, 1, self.config.d_model * 4)
# 获取解码器输入的序列长度
seq_length = decoder_input_tokens.shape[1]
# 如果使用相对位置,基于编码和掩码的长度偏移序列
decoder_positions = torch.broadcast_to(
torch.arange(seq_length, device=decoder_input_tokens.device),
(batch, seq_length),
)
# 计算位置编码
position_encodings = self.position_encoding(decoder_positions)
# 对解码器输入进行连续输入投影
inputs = self.continuous_inputs_projection(decoder_input_tokens)
# 将位置编码添加到输入中
inputs += position_encodings
# 应用 dropout 操作
y = self.dropout(inputs)
# 创建解码器掩码,没有填充
decoder_mask = torch.ones(
decoder_input_tokens.shape[:2], device=decoder_input_tokens.device, dtype=inputs.dtype
)
# 将编码掩码转换为编码器-解码器掩码
encodings_and_encdec_masks = [(x, self.encoder_decoder_mask(decoder_mask, y)) for x, y in encodings_and_masks]
# 交叉注意力风格:拼接编码
encoded = torch.cat([x[0] for x in encodings_and_encdec_masks], dim=1)
encoder_decoder_mask = torch.cat([x[1] for x in encodings_and_encdec_masks], dim=-1)
# 对每一层解码器进行循环处理
for lyr in self.decoders:
y = lyr(
y,
conditioning_emb=conditioning_emb,
encoder_hidden_states=encoded,
encoder_attention_mask=encoder_decoder_mask,
)[0]
# 对输出进行归一化
y = self.decoder_norm(y)
# 应用 dropout 后处理
y = self.post_dropout(y)
# 生成最终的频谱输出
spec_out = self.spec_out(y)
# 返回频谱输出
return spec_out
# T5 解码器层的定义
class DecoderLayer(nn.Module):
r"""
T5 decoder layer. # T5解码器层的文档说明
Args: # 参数说明
d_model (`int`): # 输入隐藏状态的大小
Size of the input hidden states. # 输入隐藏状态的大小
d_kv (`int`): # 键值投影向量的大小
Size of the key-value projection vectors. # 键值投影向量的大小
num_heads (`int`): # 注意力头的数量
Number of attention heads. # 注意力头的数量
d_ff (`int`): # 中间前馈层的大小
Size of the intermediate feed-forward layer. # 中间前馈层的大小
dropout_rate (`float`): # 丢弃概率
Dropout probability. # 丢弃概率
layer_norm_epsilon (`float`, *optional*, defaults to `1e-6`): # 数值稳定性的小值
A small value used for numerical stability to avoid dividing by zero. # 数值稳定性的小值
"""
# 初始化方法,定义各个参数
def __init__(
self, d_model: int, d_kv: int, num_heads: int, d_ff: int, dropout_rate: float, layer_norm_epsilon: float = 1e-6
):
super().__init__() # 调用父类构造函数
self.layer = nn.ModuleList() # 初始化模块列表以存储层
# 条件自注意力:第 0 层
self.layer.append(
T5LayerSelfAttentionCond(d_model=d_model, d_kv=d_kv, num_heads=num_heads, dropout_rate=dropout_rate) # 添加条件自注意力层
)
# 交叉注意力:第 1 层
self.layer.append(
T5LayerCrossAttention(
d_model=d_model, # 输入隐藏状态的大小
d_kv=d_kv, # 键值投影向量的大小
num_heads=num_heads, # 注意力头的数量
dropout_rate=dropout_rate, # 丢弃概率
layer_norm_epsilon=layer_norm_epsilon, # 数值稳定性的小值
)
)
# Film Cond MLP + 丢弃:最后一层
self.layer.append(
T5LayerFFCond(d_model=d_model, d_ff=d_ff, dropout_rate=dropout_rate, layer_norm_epsilon=layer_norm_epsilon) # 添加条件前馈层
)
# 前向传播方法
def forward(
self,
hidden_states: torch.Tensor, # 输入隐藏状态张量
conditioning_emb: Optional[torch.Tensor] = None, # 条件嵌入(可选)
attention_mask: Optional[torch.Tensor] = None, # 注意力掩码(可选)
encoder_hidden_states: Optional[torch.Tensor] = None, # 编码器隐藏状态(可选)
encoder_attention_mask: Optional[torch.Tensor] = None, # 编码器注意力掩码(可选)
encoder_decoder_position_bias=None, # 编码器-解码器位置偏置
) -> Tuple[torch.Tensor]: # 返回张量的元组
hidden_states = self.layer[0]( # 通过第一层处理输入隐藏状态
hidden_states,
conditioning_emb=conditioning_emb, # 使用条件嵌入
attention_mask=attention_mask, # 使用注意力掩码
)
# 如果存在编码器隐藏状态
if encoder_hidden_states is not None:
# 扩展编码器注意力掩码
encoder_extended_attention_mask = torch.where(encoder_attention_mask > 0, 0, -1e10).to(
encoder_hidden_states.dtype # 转换为编码器隐藏状态的数据类型
)
hidden_states = self.layer[1]( # 通过第二层处理隐藏状态
hidden_states,
key_value_states=encoder_hidden_states, # 使用编码器隐藏状态作为键值
attention_mask=encoder_extended_attention_mask, # 使用扩展的注意力掩码
)
# 应用 Film 条件前馈层
hidden_states = self.layer[-1](hidden_states, conditioning_emb) # 通过最后一层处理隐藏状态,使用条件嵌入
return (hidden_states,) # 返回处理后的隐藏状态元组
# T5样式的自注意力层,带条件
class T5LayerSelfAttentionCond(nn.Module):
r"""
T5 style self-attention layer with conditioning. # T5样式的自注意力层,带条件说明
# 函数参数说明
Args:
d_model (`int`): # 输入隐藏状态的大小
Size of the input hidden states.
d_kv (`int`): # 键值投影向量的大小
Size of the key-value projection vectors.
num_heads (`int`): # 注意力头的数量
Number of attention heads.
dropout_rate (`float`): # 丢弃概率
Dropout probability.
"""
# 初始化方法,设置类的基本参数
def __init__(self, d_model: int, d_kv: int, num_heads: int, dropout_rate: float):
super().__init__() # 调用父类构造函数
# 创建层归一化层,输入大小为 d_model
self.layer_norm = T5LayerNorm(d_model)
# 创建 FiLM 层,输入特征为 d_model * 4,输出特征为 d_model
self.FiLMLayer = T5FiLMLayer(in_features=d_model * 4, out_features=d_model)
# 创建注意力层,设定查询维度、头数、键值维度等参数
self.attention = Attention(query_dim=d_model, heads=num_heads, dim_head=d_kv, out_bias=False, scale_qk=False)
# 创建丢弃层,设定丢弃概率
self.dropout = nn.Dropout(dropout_rate)
# 前向传播方法
def forward(
self,
hidden_states: torch.Tensor, # 输入的隐藏状态张量
conditioning_emb: Optional[torch.Tensor] = None, # 可选的条件嵌入
attention_mask: Optional[torch.Tensor] = None, # 可选的注意力掩码
) -> torch.Tensor:
# 对输入的隐藏状态进行层归一化
normed_hidden_states = self.layer_norm(hidden_states)
# 如果有条件嵌入,应用 FiLM 层
if conditioning_emb is not None:
normed_hidden_states = self.FiLMLayer(normed_hidden_states, conditioning_emb)
# 自注意力模块,获取注意力输出
attention_output = self.attention(normed_hidden_states)
# 将注意力输出与原隐藏状态相加,并应用丢弃层
hidden_states = hidden_states + self.dropout(attention_output)
# 返回更新后的隐藏状态
return hidden_states
# T5风格的交叉注意力层
class T5LayerCrossAttention(nn.Module):
r"""
T5风格的交叉注意力层。
参数:
d_model (`int`):
输入隐藏状态的大小。
d_kv (`int`):
键值投影向量的大小。
num_heads (`int`):
注意力头的数量。
dropout_rate (`float`):
丢弃概率。
layer_norm_epsilon (`float`):
用于数值稳定性的小值,避免除以零。
"""
# 初始化方法,设置模型参数
def __init__(self, d_model: int, d_kv: int, num_heads: int, dropout_rate: float, layer_norm_epsilon: float):
# 调用父类初始化方法
super().__init__()
# 创建注意力层
self.attention = Attention(query_dim=d_model, heads=num_heads, dim_head=d_kv, out_bias=False, scale_qk=False)
# 创建层归一化层
self.layer_norm = T5LayerNorm(d_model, eps=layer_norm_epsilon)
# 创建丢弃层
self.dropout = nn.Dropout(dropout_rate)
# 前向传播方法
def forward(
self,
hidden_states: torch.Tensor,
key_value_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
# 对隐藏状态进行层归一化
normed_hidden_states = self.layer_norm(hidden_states)
# 计算注意力输出
attention_output = self.attention(
normed_hidden_states,
encoder_hidden_states=key_value_states,
attention_mask=attention_mask.squeeze(1),
)
# 计算层输出,添加丢弃
layer_output = hidden_states + self.dropout(attention_output)
# 返回层输出
return layer_output
# T5风格的前馈条件层
class T5LayerFFCond(nn.Module):
r"""
T5风格的前馈条件层。
参数:
d_model (`int`):
输入隐藏状态的大小。
d_ff (`int`):
中间前馈层的大小。
dropout_rate (`float`):
丢弃概率。
layer_norm_epsilon (`float`):
用于数值稳定性的小值,避免除以零。
"""
# 初始化方法,设置模型参数
def __init__(self, d_model: int, d_ff: int, dropout_rate: float, layer_norm_epsilon: float):
# 调用父类初始化方法
super().__init__()
# 创建带门激活的前馈层
self.DenseReluDense = T5DenseGatedActDense(d_model=d_model, d_ff=d_ff, dropout_rate=dropout_rate)
# 创建条件层
self.film = T5FiLMLayer(in_features=d_model * 4, out_features=d_model)
# 创建层归一化层
self.layer_norm = T5LayerNorm(d_model, eps=layer_norm_epsilon)
# 创建丢弃层
self.dropout = nn.Dropout(dropout_rate)
# 前向传播方法
def forward(self, hidden_states: torch.Tensor, conditioning_emb: Optional[torch.Tensor] = None) -> torch.Tensor:
# 对隐藏状态进行层归一化
forwarded_states = self.layer_norm(hidden_states)
# 如果存在条件嵌入,则应用条件层
if conditioning_emb is not None:
forwarded_states = self.film(forwarded_states, conditioning_emb)
# 应用前馈层
forwarded_states = self.DenseReluDense(forwarded_states)
# 更新隐藏状态,添加丢弃
hidden_states = hidden_states + self.dropout(forwarded_states)
# 返回更新后的隐藏状态
return hidden_states
# T5风格的前馈层,具有门控激活和丢弃
class T5DenseGatedActDense(nn.Module):
r"""
T5风格的前馈层,具有门控激活和丢弃。
# 参数说明部分
Args:
d_model (`int`): # 输入隐藏状态的尺寸
Size of the input hidden states.
d_ff (`int`): # 中间前馈层的尺寸
Size of the intermediate feed-forward layer.
dropout_rate (`float`): # 丢弃概率
Dropout probability.
"""
# 初始化方法,接受模型参数
def __init__(self, d_model: int, d_ff: int, dropout_rate: float):
super().__init__() # 调用父类的初始化方法
# 定义第一线性变换层,不使用偏置,输入维度为d_model,输出维度为d_ff
self.wi_0 = nn.Linear(d_model, d_ff, bias=False)
# 定义第二线性变换层,不使用偏置,输入维度为d_model,输出维度为d_ff
self.wi_1 = nn.Linear(d_model, d_ff, bias=False)
# 定义输出线性变换层,不使用偏置,输入维度为d_ff,输出维度为d_model
self.wo = nn.Linear(d_ff, d_model, bias=False)
# 定义丢弃层,使用指定的丢弃概率
self.dropout = nn.Dropout(dropout_rate)
# 初始化自定义激活函数
self.act = NewGELUActivation()
# 前向传播方法,接受输入的隐藏状态
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
# 通过第一线性层和激活函数得到隐层状态
hidden_gelu = self.act(self.wi_0(hidden_states))
# 通过第二线性层得到隐层状态
hidden_linear = self.wi_1(hidden_states)
# 将两个隐层状态进行逐元素相乘
hidden_states = hidden_gelu * hidden_linear
# 应用丢弃层
hidden_states = self.dropout(hidden_states)
# 通过输出线性层得到最终的隐层状态
hidden_states = self.wo(hidden_states)
# 返回最终的隐层状态
return hidden_states
# T5风格的层归一化模块
class T5LayerNorm(nn.Module):
r"""
T5风格的层归一化模块。
Args:
hidden_size (`int`):
输入隐藏状态的大小。
eps (`float`, `optional`, defaults to `1e-6`):
用于数值稳定性的小值,以避免除以零。
"""
# 初始化函数,接受隐藏状态大小和epsilon
def __init__(self, hidden_size: int, eps: float = 1e-6):
"""
构造一个T5风格的层归一化模块。没有偏置,也不减去均值。
"""
# 调用父类构造函数
super().__init__()
# 初始化权重为全1的可学习参数
self.weight = nn.Parameter(torch.ones(hidden_size))
# 存储epsilon值
self.variance_epsilon = eps
# 前向传播函数,接受隐藏状态
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
# 计算隐藏状态的方差,使用平方和的均值,保持维度
variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
# 按照方差进行归一化,同时考虑到epsilon
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
# 如果权重为半精度,转换隐藏状态为相应类型
if self.weight.dtype in [torch.float16, torch.bfloat16]:
hidden_states = hidden_states.to(self.weight.dtype)
# 返回归一化后的结果乘以权重
return self.weight * hidden_states
# 实现GELU激活函数的模块
class NewGELUActivation(nn.Module):
"""
实现与Google BERT库中相同的GELU激活函数(与OpenAI GPT相同)。也可以参考
Gaussian Error Linear Units论文:https://arxiv.org/abs/1606.08415
"""
# 前向传播函数,接受输入张量
def forward(self, input: torch.Tensor) -> torch.Tensor:
# 计算GELU激活值
return 0.5 * input * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (input + 0.044715 * torch.pow(input, 3.0))))
# T5风格的FiLM层
class T5FiLMLayer(nn.Module):
"""
T5风格的FiLM层。
Args:
in_features (`int`):
输入特征的数量。
out_features (`int`):
输出特征的数量。
"""
# 初始化函数,接受输入和输出特征数量
def __init__(self, in_features: int, out_features: int):
# 调用父类构造函数
super().__init__()
# 定义线性层,用于生成缩放和偏移参数
self.scale_bias = nn.Linear(in_features, out_features * 2, bias=False)
# 前向传播函数,接受输入张量和条件嵌入
def forward(self, x: torch.Tensor, conditioning_emb: torch.Tensor) -> torch.Tensor:
# 通过线性层计算缩放和偏移
emb = self.scale_bias(conditioning_emb)
# 将结果分成缩放和偏移两个部分
scale, shift = torch.chunk(emb, 2, -1)
# 进行缩放和偏移操作
x = x * (1 + scale) + shift
# 返回处理后的结果
return x