diffusers 源码解析(十五)
.\diffusers\models\unets\unet_3d_condition.py
# 版权声明,声明此代码的版权信息和所有权
# Copyright 2024 Alibaba DAMO-VILAB and The HuggingFace Team. All rights reserved.
# 版权声明,声明此代码的版权信息和所有权
# Copyright 2024 The ModelScope Team.
#
# 许可声明,声明本代码使用的 Apache 许可证 2.0 版本
# Licensed under the Apache License, Version 2.0 (the "License");
# 使用此文件前需遵守许可证规定
# you may not use this file except in compliance with the License.
# 可在以下网址获取许可证副本
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# 免责声明,说明软件在许可下按 "原样" 提供,不附加任何明示或暗示的保证
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# 许可证中规定的权限和限制说明
# See the License for the specific language governing permissions and
# limitations under the License.
# 从 dataclasses 模块导入 dataclass 装饰器
from dataclasses import dataclass
# 从 typing 模块导入所需的类型提示
from typing import Any, Dict, List, Optional, Tuple, Union
# 导入 PyTorch 库
import torch
# 导入 PyTorch 神经网络模块
import torch.nn as nn
# 导入 PyTorch 的检查点工具
import torch.utils.checkpoint
# 导入配置相关的工具类和函数
from ...configuration_utils import ConfigMixin, register_to_config
# 导入 UNet2D 条件加载器混合类
from ...loaders import UNet2DConditionLoadersMixin
# 导入基本输出类和日志工具
from ...utils import BaseOutput, logging
# 导入激活函数获取工具
from ..activations import get_activation
# 导入各种注意力处理器相关组件
from ..attention_processor import (
ADDED_KV_ATTENTION_PROCESSORS, # 导入添加键值对注意力处理器
CROSS_ATTENTION_PROCESSORS, # 导入交叉注意力处理器
Attention, # 导入注意力类
AttentionProcessor, # 导入注意力处理器基类
AttnAddedKVProcessor, # 导入添加键值对的注意力处理器
AttnProcessor, # 导入普通注意力处理器
FusedAttnProcessor2_0, # 导入融合注意力处理器
)
# 导入时间步嵌入和时间步类
from ..embeddings import TimestepEmbedding, Timesteps
# 导入模型混合类
from ..modeling_utils import ModelMixin
# 导入时间变换器模型
from ..transformers.transformer_temporal import TransformerTemporalModel
# 导入 3D UNet 相关的块
from .unet_3d_blocks import (
CrossAttnDownBlock3D, # 导入交叉注意力下采样块
CrossAttnUpBlock3D, # 导入交叉注意力上采样块
DownBlock3D, # 导入下采样块
UNetMidBlock3DCrossAttn, # 导入 UNet 中间交叉注意力块
UpBlock3D, # 导入上采样块
get_down_block, # 导入获取下采样块的函数
get_up_block, # 导入获取上采样块的函数
)
# 创建日志记录器,使用当前模块的名称
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
# 定义 UNet3DConditionOutput 数据类,继承自 BaseOutput
@dataclass
class UNet3DConditionOutput(BaseOutput):
"""
[`UNet3DConditionModel`] 的输出类。
参数:
sample (`torch.Tensor` 的形状为 `(batch_size, num_channels, num_frames, height, width)`):
基于 `encoder_hidden_states` 输入的隐藏状态输出。模型最后一层的输出。
"""
sample: torch.Tensor # 定义样本输出,类型为 PyTorch 张量
# 定义 UNet3DConditionModel 类,继承自多个混合类
class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
r"""
条件 3D UNet 模型,接受噪声样本、条件状态和时间步,并返回形状为样本的输出。
此模型继承自 [`ModelMixin`]。有关其通用方法的文档,请参阅超类文档(如下载或保存)。
# 参数说明部分
Parameters:
# 输入/输出样本的高度和宽度,类型可以为整数或元组,默认为 None
sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
Height and width of input/output sample.
# 输入样本的通道数,默认为 4
in_channels (`int`, *optional*, defaults to 4): The number of channels in the input sample.
# 输出的通道数,默认为 4
out_channels (`int`, *optional*, defaults to 4): The number of channels in the output.
# 使用的下采样块类型的元组,默认为指定的四种块
down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock3D", "CrossAttnDownBlock3D", "CrossAttnDownBlock3D", "DownBlock3D")`):
The tuple of downsample blocks to use.
# 使用的上采样块类型的元组,默认为指定的四种块
up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D")`):
The tuple of upsample blocks to use.
# 每个块的输出通道数的元组,默认为 (320, 640, 1280, 1280)
block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
The tuple of output channels for each block.
# 每个块的层数,默认为 2
layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
# 下采样卷积使用的填充,默认为 1
downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
# 中间块使用的缩放因子,默认为 1.0
mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
# 使用的激活函数,默认为 "silu"
act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
# 用于归一化的组数,默认为 32;如果为 None,则跳过归一化和激活层
norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
If `None`, normalization and activation layers is skipped in post-processing.
# 归一化使用的 epsilon 值,默认为 1e-5
norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
# 交叉注意力特征的维度,默认为 1024
cross_attention_dim (`int`, *optional*, defaults to 1024): The dimension of the cross attention features.
# 注意力头的维度,默认为 64
attention_head_dim (`int`, *optional*, defaults to 64): The dimension of the attention heads.
# 注意力头的数量,类型为整数,默认为 None
num_attention_heads (`int`, *optional*): The number of attention heads.
# 时间条件投影层的维度,默认为 None
time_cond_proj_dim (`int`, *optional*, defaults to `None`):
The dimension of `cond_proj` layer in the timestep embedding.
"""
# 是否支持梯度检查点,默认为 False
_supports_gradient_checkpointing = False
# 将此类注册到配置中
@register_to_config
# 初始化方法,用于创建类的实例
def __init__(
# 样本大小,默认为 None
self,
sample_size: Optional[int] = None,
# 输入通道数量,默认为 4
in_channels: int = 4,
# 输出通道数量,默认为 4
out_channels: int = 4,
# 下采样块类型的元组,定义模型的下采样结构
down_block_types: Tuple[str, ...] = (
"CrossAttnDownBlock3D",
"CrossAttnDownBlock3D",
"CrossAttnDownBlock3D",
"DownBlock3D",
),
# 上采样块类型的元组,定义模型的上采样结构
up_block_types: Tuple[str, ...] = (
"UpBlock3D",
"CrossAttnUpBlock3D",
"CrossAttnUpBlock3D",
"CrossAttnUpBlock3D",
),
# 每个块的输出通道数量,定义模型每个层的通道设置
block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
# 每个块的层数,默认为 2
layers_per_block: int = 2,
# 下采样时的填充大小,默认为 1
downsample_padding: int = 1,
# 中间块的缩放因子,默认为 1
mid_block_scale_factor: float = 1,
# 激活函数类型,默认为 "silu"
act_fn: str = "silu",
# 归一化组的数量,默认为 32
norm_num_groups: Optional[int] = 32,
# 归一化的 epsilon 值,默认为 1e-5
norm_eps: float = 1e-5,
# 跨注意力维度,默认为 1024
cross_attention_dim: int = 1024,
# 注意力头的维度,可以是单一整数或整数元组,默认为 64
attention_head_dim: Union[int, Tuple[int]] = 64,
# 注意力头的数量,可选参数
num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
# 时间条件投影维度,可选参数
time_cond_proj_dim: Optional[int] = None,
@property
# 从 UNet2DConditionModel 复制的属性,获取注意力处理器
# 返回所有注意力处理器的字典,以权重名称为索引
def attn_processors(self) -> Dict[str, AttentionProcessor]:
r"""
Returns:
`dict` of attention processors: A dictionary containing all attention processors used in the model with
indexed by its weight name.
"""
# 初始化处理器字典
processors = {}
# 递归添加处理器的函数
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
# 如果模块有获取处理器的方法,添加到处理器字典中
if hasattr(module, "get_processor"):
processors[f"{name}.processor"] = module.get_processor()
# 遍历子模块,递归调用该函数
for sub_name, child in module.named_children():
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
# 返回处理器字典
return processors
# 遍历当前类的子模块,调用递归添加处理器的函数
for name, module in self.named_children():
fn_recursive_add_processors(name, module, processors)
# 返回所有处理器
return processors
# 从 UNet2DConditionModel 复制的设置注意力切片的方法
# 从 UNet2DConditionModel 复制的设置注意力处理器的方法
# 定义一个方法用于设置注意力处理器
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
r"""
设置用于计算注意力的处理器。
参数:
processor(`dict` of `AttentionProcessor` 或仅 `AttentionProcessor`):
实例化的处理器类或一个处理器类的字典,将作为所有 `Attention` 层的处理器。
如果 `processor` 是一个字典,键需要定义相应的交叉注意力处理器的路径。
在设置可训练的注意力处理器时,强烈推荐这样做。
"""
# 获取当前注意力处理器的数量
count = len(self.attn_processors.keys())
# 如果传入的处理器是字典,且数量不等于注意力层数量,抛出错误
if isinstance(processor, dict) and len(processor) != count:
raise ValueError(
f"传入了一个处理器字典,但处理器的数量 {len(processor)} 与"
f" 注意力层的数量 {count} 不匹配。请确保传入 {count} 个处理器类。"
)
# 定义一个递归函数来设置每个模块的处理器
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
# 如果模块有设置处理器的方法
if hasattr(module, "set_processor"):
# 如果处理器不是字典,直接设置处理器
if not isinstance(processor, dict):
module.set_processor(processor)
else:
# 从字典中获取相应的处理器并设置
module.set_processor(processor.pop(f"{name}.processor"))
# 遍历子模块并递归调用处理器设置
for sub_name, child in module.named_children():
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
# 遍历当前对象的所有子模块,并调用递归设置函数
for name, module in self.named_children():
fn_recursive_attn_processor(name, module, processor)
# 定义一个方法来启用前馈层的分块处理
def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None:
"""
设置注意力处理器以使用 [前馈分块](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers)。
参数:
chunk_size (`int`, *可选*):
前馈层的分块大小。如果未指定,将对维度为`dim`的每个张量单独运行前馈层。
dim (`int`, *可选*, 默认为`0`):
应对哪个维度进行前馈计算的分块。可以选择 dim=0(批次)或 dim=1(序列长度)。
"""
# 确保 dim 参数为 0 或 1
if dim not in [0, 1]:
raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}")
# 默认的分块大小为 1
chunk_size = chunk_size or 1
# 定义一个递归函数来设置每个模块的分块前馈处理
def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
# 如果模块具有设置分块前馈的属性,则设置它
if hasattr(module, "set_chunk_feed_forward"):
module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)
# 遍历子模块,递归调用函数
for child in module.children():
fn_recursive_feed_forward(child, chunk_size, dim)
# 遍历当前实例的子模块,应用递归函数
for module in self.children():
fn_recursive_feed_forward(module, chunk_size, dim)
# 定义一个方法来禁用前馈层的分块处理
def disable_forward_chunking(self):
# 定义一个递归函数来禁用分块前馈处理
def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
# 如果模块具有设置分块前馈的属性,则设置为 None
if hasattr(module, "set_chunk_feed_forward"):
module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)
# 遍历子模块,递归调用函数
for child in module.children():
fn_recursive_feed_forward(child, chunk_size, dim)
# 遍历当前实例的子模块,应用递归函数,禁用分块
for module in self.children():
fn_recursive_feed_forward(module, None, 0)
# 从 diffusers.models.unets.unet_2d_condition 中复制的方法,设置默认注意力处理器
def set_default_attn_processor(self):
"""
禁用自定义注意力处理器并设置默认注意力实现。
"""
# 检查所有注意力处理器是否为添加的 KV 处理器
if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
processor = AttnAddedKVProcessor() # 设置为添加的 KV 处理器
# 检查所有注意力处理器是否为交叉注意力处理器
elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
processor = AttnProcessor() # 设置为普通注意力处理器
else:
# 抛出异常,若注意力处理器类型不符合预期
raise ValueError(
f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
)
# 设置选定的注意力处理器
self.set_attn_processor(processor)
# 定义一个私有方法来设置模块的梯度检查点
def _set_gradient_checkpointing(self, module, value: bool = False) -> None:
# 检查模块是否属于特定类型
if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)):
module.gradient_checkpointing = value # 设置梯度检查点值
# 从 diffusers.models.unets.unet_2d_condition 中复制的方法,启用自由度
# 启用 FreeU 机制,参数为两个缩放因子和两个增强因子的值
def enable_freeu(self, s1, s2, b1, b2):
r"""从 https://arxiv.org/abs/2309.11497 启用 FreeU 机制。
缩放因子的后缀表示它们应用的阶段块。
请参考 [官方仓库](https://github.com/ChenyangSi/FreeU) 以获取在不同管道(如 Stable Diffusion v1、v2 和 Stable Diffusion XL)中已知效果良好的值组合。
Args:
s1 (`float`):
第1阶段的缩放因子,用于减弱跳跃特征的贡献,以减轻增强去噪过程中的“过平滑效应”。
s2 (`float`):
第2阶段的缩放因子,用于减弱跳跃特征的贡献,以减轻增强去噪过程中的“过平滑效应”。
b1 (`float`): 第1阶段的缩放因子,用于增强骨干特征的贡献。
b2 (`float`): 第2阶段的缩放因子,用于增强骨干特征的贡献。
"""
# 遍历上采样块,给每个块设置缩放因子和增强因子
for i, upsample_block in enumerate(self.up_blocks):
# 设置第1阶段的缩放因子
setattr(upsample_block, "s1", s1)
# 设置第2阶段的缩放因子
setattr(upsample_block, "s2", s2)
# 设置第1阶段的增强因子
setattr(upsample_block, "b1", b1)
# 设置第2阶段的增强因子
setattr(upsample_block, "b2", b2)
# 从 diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.disable_freeu 复制
# 禁用 FreeU 机制
def disable_freeu(self):
"""禁用 FreeU 机制。"""
# 定义 FreeU 机制的关键属性
freeu_keys = {"s1", "s2", "b1", "b2"}
# 遍历上采样块
for i, upsample_block in enumerate(self.up_blocks):
# 遍历 FreeU 关键属性
for k in freeu_keys:
# 如果上采样块有该属性,或者该属性值不为 None
if hasattr(upsample_block, k) or getattr(upsample_block, k, None) is not None:
# 将属性值设置为 None,禁用 FreeU
setattr(upsample_block, k, None)
# 从 diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections 复制
# 启用融合的 QKV 投影
def fuse_qkv_projections(self):
"""
启用融合的 QKV 投影。对于自注意力模块,所有投影矩阵(即查询、键、值)都被融合。对于交叉注意力模块,键和值投影矩阵被融合。
<Tip warning={true}>
此 API 是 🧪 实验性的。
</Tip>
"""
# 保存原始的注意力处理器
self.original_attn_processors = None
# 遍历注意力处理器
for _, attn_processor in self.attn_processors.items():
# 如果注意力处理器的类名中包含“Added”
if "Added" in str(attn_processor.__class__.__name__):
# 抛出错误,表示不支持具有附加 KV 投影的模型
raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
# 保存当前的注意力处理器
self.original_attn_processors = self.attn_processors
# 遍历所有模块
for module in self.modules():
# 如果模块是 Attention 类型
if isinstance(module, Attention):
# 融合投影
module.fuse_projections(fuse=True)
# 设置注意力处理器为融合的注意力处理器
self.set_attn_processor(FusedAttnProcessor2_0())
# 从 diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections 复制
# 定义一个方法,用于禁用已启用的融合 QKV 投影
def unfuse_qkv_projections(self):
"""禁用已启用的融合 QKV 投影。
<Tip warning={true}>
该 API 是 🧪 实验性的。
</Tip>
"""
# 如果存在原始的注意力处理器,则设置当前的注意力处理器为原始处理器
if self.original_attn_processors is not None:
self.set_attn_processor(self.original_attn_processors)
# 定义前向传播方法,接受多个参数进行计算
def forward(
self,
sample: torch.Tensor, # 输入样本,张量格式
timestep: Union[torch.Tensor, float, int], # 当前时间步,可以是张量、浮点数或整数
encoder_hidden_states: torch.Tensor, # 编码器的隐藏状态,张量格式
class_labels: Optional[torch.Tensor] = None, # 类别标签,默认为 None
timestep_cond: Optional[torch.Tensor] = None, # 时间步条件,默认为 None
attention_mask: Optional[torch.Tensor] = None, # 注意力掩码,默认为 None
cross_attention_kwargs: Optional[Dict[str, Any]] = None, # 跨注意力的关键字参数,默认为 None
down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None, # 降级块的附加残差,默认为 None
mid_block_additional_residual: Optional[torch.Tensor] = None, # 中间块的附加残差,默认为 None
return_dict: bool = True, # 是否返回字典格式的结果,默认为 True
.\diffusers\models\unets\unet_i2vgen_xl.py
# 版权声明,表明版权归2024年阿里巴巴DAMO-VILAB和HuggingFace团队所有
# 提供Apache许可证2.0版本的使用条款
# 说明只能在遵循许可证的情况下使用此文件
# 可在指定网址获取许可证副本
#
# 除非适用法律或书面协议另有约定,否则软件按“原样”分发
# 不提供任何形式的担保或条件
# 请参见许可证以获取与权限和限制相关的具体信息
from typing import Any, Dict, Optional, Tuple, Union # 导入类型提示工具,用于类型注解
import torch # 导入PyTorch库
import torch.nn as nn # 导入PyTorch的神经网络模块
import torch.utils.checkpoint # 导入PyTorch的检查点工具
from ...configuration_utils import ConfigMixin, register_to_config # 从配置工具导入类和函数
from ...loaders import UNet2DConditionLoadersMixin # 导入2D条件加载器混合类
from ...utils import logging # 导入日志工具
from ..activations import get_activation # 导入激活函数获取工具
from ..attention import Attention, FeedForward # 导入注意力机制和前馈网络
from ..attention_processor import ( # 从注意力处理器模块导入多个处理器
ADDED_KV_ATTENTION_PROCESSORS,
CROSS_ATTENTION_PROCESSORS,
AttentionProcessor,
AttnAddedKVProcessor,
AttnProcessor,
FusedAttnProcessor2_0,
)
from ..embeddings import TimestepEmbedding, Timesteps # 导入时间步嵌入和时间步类
from ..modeling_utils import ModelMixin # 导入模型混合类
from ..transformers.transformer_temporal import TransformerTemporalModel # 导入时间变换器模型
from .unet_3d_blocks import ( # 从3D U-Net块模块导入多个类
CrossAttnDownBlock3D,
CrossAttnUpBlock3D,
DownBlock3D,
UNetMidBlock3DCrossAttn,
UpBlock3D,
get_down_block,
get_up_block,
)
from .unet_3d_condition import UNet3DConditionOutput # 导入3D条件输出类
logger = logging.get_logger(__name__) # 创建日志记录器,用于记录当前模块的信息
class I2VGenXLTransformerTemporalEncoder(nn.Module): # 定义一个名为I2VGenXLTransformerTemporalEncoder的类,继承自nn.Module
def __init__( # 构造函数,用于初始化类的实例
self,
dim: int, # 输入的特征维度
num_attention_heads: int, # 注意力头的数量
attention_head_dim: int, # 每个注意力头的维度
activation_fn: str = "geglu", # 激活函数类型,默认使用geglu
upcast_attention: bool = False, # 是否提升注意力计算的精度
ff_inner_dim: Optional[int] = None, # 前馈网络的内部维度,默认为None
dropout: int = 0.0, # dropout概率,默认为0.0
):
super().__init__() # 调用父类构造函数
self.norm1 = nn.LayerNorm(dim, elementwise_affine=True, eps=1e-5) # 初始化层归一化层
self.attn1 = Attention( # 初始化注意力层
query_dim=dim, # 查询维度
heads=num_attention_heads, # 注意力头数量
dim_head=attention_head_dim, # 每个头的维度
dropout=dropout, # dropout概率
bias=False, # 不使用偏置
upcast_attention=upcast_attention, # 是否提升注意力计算精度
out_bias=True, # 输出使用偏置
)
self.ff = FeedForward( # 初始化前馈网络
dim, # 输入维度
dropout=dropout, # dropout概率
activation_fn=activation_fn, # 激活函数类型
final_dropout=False, # 最后层不使用dropout
inner_dim=ff_inner_dim, # 内部维度
bias=True, # 使用偏置
)
def forward( # 定义前向传播方法
self,
hidden_states: torch.Tensor, # 输入的隐藏状态
# 该方法返回处理后的隐藏状态张量
) -> torch.Tensor:
# 对隐藏状态进行归一化处理
norm_hidden_states = self.norm1(hidden_states)
# 计算注意力输出,使用归一化后的隐藏状态
attn_output = self.attn1(norm_hidden_states, encoder_hidden_states=None)
# 将注意力输出与原始隐藏状态相加,更新隐藏状态
hidden_states = attn_output + hidden_states
# 如果隐藏状态是四维,则去掉第一维
if hidden_states.ndim == 4:
hidden_states = hidden_states.squeeze(1)
# 通过前馈网络处理隐藏状态
ff_output = self.ff(hidden_states)
# 将前馈输出与当前隐藏状态相加,更新隐藏状态
hidden_states = ff_output + hidden_states
# 如果隐藏状态是四维,则去掉第一维
if hidden_states.ndim == 4:
hidden_states = hidden_states.squeeze(1)
# 返回最终的隐藏状态
return hidden_states
# 定义 I2VGenXL UNet 类,继承多个混入类以增加功能
class I2VGenXLUNet(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
r"""
I2VGenXL UNet。一个条件3D UNet模型,接收噪声样本、条件状态和时间步,
返回与样本形状相同的输出。
该模型继承自 [`ModelMixin`]。有关所有模型实现的通用方法(如下载或保存),
请查看超类文档。
参数:
sample_size (`int` 或 `Tuple[int, int]`, *可选*, 默认值为 `None`):
输入/输出样本的高度和宽度。
in_channels (`int`, *可选*, 默认值为 4): 输入样本的通道数。
out_channels (`int`, *可选*, 默认值为 4): 输出样本的通道数。
down_block_types (`Tuple[str]`, *可选*, 默认值为 `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
使用的下采样块的元组。
up_block_types (`Tuple[str]`, *可选*, 默认值为 `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`):
使用的上采样块的元组。
block_out_channels (`Tuple[int]`, *可选*, 默认值为 `(320, 640, 1280, 1280)`):
每个块的输出通道元组。
layers_per_block (`int`, *可选*, 默认值为 2): 每个块的层数。
norm_num_groups (`int`, *可选*, 默认值为 32): 用于归一化的组数。
如果为 `None`,则跳过后处理中的归一化和激活层。
cross_attention_dim (`int`, *可选*, 默认值为 1280): 跨注意力特征的维度。
attention_head_dim (`int`, *可选*, 默认值为 64): 注意力头的维度。
num_attention_heads (`int`, *可选*): 注意力头的数量。
"""
# 设置不支持梯度检查点的属性为 False
_supports_gradient_checkpointing = False
@register_to_config
# 初始化方法,接受多种可选参数以设置模型配置
def __init__(
self,
sample_size: Optional[int] = None, # 输入/输出样本大小,默认为 None
in_channels: int = 4, # 输入样本的通道数,默认为 4
out_channels: int = 4, # 输出样本的通道数,默认为 4
down_block_types: Tuple[str, ...] = ( # 下采样块的类型,默认为指定的元组
"CrossAttnDownBlock3D",
"CrossAttnDownBlock3D",
"CrossAttnDownBlock3D",
"DownBlock3D",
),
up_block_types: Tuple[str, ...] = ( # 上采样块的类型,默认为指定的元组
"UpBlock3D",
"CrossAttnUpBlock3D",
"CrossAttnUpBlock3D",
"CrossAttnUpBlock3D",
),
block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280), # 每个块的输出通道,默认为指定的元组
layers_per_block: int = 2, # 每个块的层数,默认为 2
norm_num_groups: Optional[int] = 32, # 归一化组数,默认为 32
cross_attention_dim: int = 1024, # 跨注意力特征的维度,默认为 1024
attention_head_dim: Union[int, Tuple[int]] = 64, # 注意力头的维度,默认为 64
num_attention_heads: Optional[Union[int, Tuple[int]]] = None, # 注意力头的数量,默认为 None
@property
# 该属性从 UNet2DConditionModel 的 attn_processors 复制
# 定义返回注意力处理器的函数,返回类型为字典,键为字符串,值为 AttentionProcessor 对象
def attn_processors(self) -> Dict[str, AttentionProcessor]:
r"""
Returns:
`dict` of attention processors: A dictionary containing all attention processors used in the model with
indexed by its weight name.
"""
# 创建一个空字典,用于存储处理器
processors = {}
# 定义一个递归函数,用于添加处理器到字典
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
# 检查模块是否有 get_processor 方法
if hasattr(module, "get_processor"):
# 将处理器添加到字典中,键为名称加上 ".processor"
processors[f"{name}.processor"] = module.get_processor()
# 遍历模块的子模块
for sub_name, child in module.named_children():
# 递归调用,处理子模块
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
# 返回更新后的处理器字典
return processors
# 遍历当前对象的所有子模块
for name, module in self.named_children():
# 调用递归函数,将处理器添加到字典中
fn_recursive_add_processors(name, module, processors)
# 返回包含所有处理器的字典
return processors
# 从 diffusers.models.unets.unet_2d_condition 中复制的设置注意力处理器的函数
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
r"""
Sets the attention processor to use to compute attention.
Parameters:
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
The instantiated processor class or a dictionary of processor classes that will be set as the processor
for **all** `Attention` layers.
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
processor. This is strongly recommended when setting trainable attention processors.
"""
# 获取当前注意力处理器的数量
count = len(self.attn_processors.keys())
# 如果传入的是字典且数量不匹配,则引发错误
if isinstance(processor, dict) and len(processor) != count:
raise ValueError(
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
)
# 定义一个递归函数,用于设置处理器
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
# 检查模块是否有 set_processor 方法
if hasattr(module, "set_processor"):
# 如果传入的处理器不是字典,直接设置
if not isinstance(processor, dict):
module.set_processor(processor)
else:
# 从字典中移除并设置对应的处理器
module.set_processor(processor.pop(f"{name}.processor"))
# 遍历模块的子模块
for sub_name, child in module.named_children():
# 递归调用,设置子模块的处理器
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
# 遍历当前对象的所有子模块
for name, module in self.named_children():
# 调用递归函数,为每个模块设置处理器
fn_recursive_attn_processor(name, module, processor)
# 从 diffusers.models.unets.unet_3d_condition 中复制的启用前向分块的函数
# 启用前馈层的分块处理
def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None:
"""
设置注意力处理器使用[前馈分块](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers)。
参数:
chunk_size (`int`, *可选*):
前馈层的块大小。如果未指定,将对维度为`dim`的每个张量单独运行前馈层。
dim (`int`, *可选*, 默认为`0`):
前馈计算应分块的维度。可以选择dim=0(批次)或dim=1(序列长度)。
"""
# 检查维度是否在有效范围内
if dim not in [0, 1]:
# 抛出错误,确保dim只为0或1
raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}")
# 默认块大小为1
chunk_size = chunk_size or 1
# 定义递归函数,用于设置每个模块的前馈分块
def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
# 如果模块有设置分块前馈的方法,调用该方法
if hasattr(module, "set_chunk_feed_forward"):
module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)
# 递归遍历子模块
for child in module.children():
fn_recursive_feed_forward(child, chunk_size, dim)
# 对当前对象的所有子模块应用前馈分块设置
for module in self.children():
fn_recursive_feed_forward(module, chunk_size, dim)
# 从diffusers.models.unets.unet_3d_condition.UNet3DConditionModel复制的禁用前馈分块的方法
def disable_forward_chunking(self):
# 定义递归函数,用于禁用模块的前馈分块
def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
# 如果模块有设置分块前馈的方法,调用该方法
if hasattr(module, "set_chunk_feed_forward"):
module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)
# 递归遍历子模块
for child in module.children():
fn_recursive_feed_forward(child, chunk_size, dim)
# 对当前对象的所有子模块应用禁用前馈分块设置
for module in self.children():
fn_recursive_feed_forward(module, None, 0)
# 从diffusers.models.unets.unet_2d_condition.UNet2DConditionModel复制的设置默认注意力处理器的方法
def set_default_attn_processor(self):
"""
禁用自定义注意力处理器并设置默认的注意力实现。
"""
# 检查所有注意力处理器是否属于已添加的KV注意力处理器类
if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
# 如果是,则设置为已添加KV处理器
processor = AttnAddedKVProcessor()
# 检查所有注意力处理器是否属于交叉注意力处理器类
elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
# 如果是,则设置为标准注意力处理器
processor = AttnProcessor()
else:
# 抛出错误,说明当前的注意力处理器类型不被支持
raise ValueError(
f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
)
# 设置当前对象的注意力处理器为选择的处理器
self.set_attn_processor(processor)
# 从diffusers.models.unets.unet_3d_condition.UNet3DConditionModel复制的设置梯度检查点的方法
# 设置梯度检查点,指定模块和布尔值
def _set_gradient_checkpointing(self, module, value: bool = False) -> None:
# 检查模块是否为指定的类型之一
if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)):
# 设置模块的梯度检查点属性为指定值
module.gradient_checkpointing = value
# 从 UNet2DConditionModel 中复制的启用 FreeU 方法
def enable_freeu(self, s1, s2, b1, b2):
r"""启用 FreeU 机制,详情见 https://arxiv.org/abs/2309.11497.
后缀表示缩放因子应用的阶段块。
请参考 [官方库](https://github.com/ChenyangSi/FreeU) 以获取适用于不同管道(如 Stable Diffusion v1, v2 和 Stable Diffusion XL)的有效值组合。
参数:
s1 (`float`):
阶段 1 的缩放因子,用于减弱跳过特征的贡献,以缓解增强去噪过程中的“过平滑效应”。
s2 (`float`):
阶段 2 的缩放因子,用于减弱跳过特征的贡献,以缓解增强去噪过程中的“过平滑效应”。
b1 (`float`): 阶段 1 的缩放因子,用于放大主干特征的贡献。
b2 (`float`): 阶段 2 的缩放因子,用于放大主干特征的贡献。
"""
# 遍历上采样块,索引 i 和块对象 upsample_block
for i, upsample_block in enumerate(self.up_blocks):
# 设置上采样块的属性 s1 为给定值 s1
setattr(upsample_block, "s1", s1)
# 设置上采样块的属性 s2 为给定值 s2
setattr(upsample_block, "s2", s2)
# 设置上采样块的属性 b1 为给定值 b1
setattr(upsample_block, "b1", b1)
# 设置上采样块的属性 b2 为给定值 b2
setattr(upsample_block, "b2", b2)
# 从 UNet2DConditionModel 中复制的禁用 FreeU 方法
def disable_freeu(self):
"""禁用 FreeU 机制。"""
# 定义 FreeU 相关的属性键
freeu_keys = {"s1", "s2", "b1", "b2"}
# 遍历上采样块,索引 i 和块对象 upsample_block
for i, upsample_block in enumerate(self.up_blocks):
# 遍历 FreeU 属性键
for k in freeu_keys:
# 如果上采样块具有该属性或属性值不为 None
if hasattr(upsample_block, k) or getattr(upsample_block, k, None) is not None:
# 将上采样块的该属性设置为 None
setattr(upsample_block, k, None)
# 从 UNet2DConditionModel 中复制的融合 QKV 投影方法
# 定义一个方法,用于启用融合的 QKV 投影
def fuse_qkv_projections(self):
# 提供方法的文档字符串,描述其功能和警告信息
"""
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
are fused. For cross-attention modules, key and value projection matrices are fused.
<Tip warning={true}>
This API is 🧪 experimental.
</Tip>
"""
# 初始化原始注意力处理器为 None
self.original_attn_processors = None
# 遍历当前对象的注意力处理器
for _, attn_processor in self.attn_processors.items():
# 检查处理器类名中是否包含 "Added"
if "Added" in str(attn_processor.__class__.__name__):
# 如果包含,抛出异常提示不支持
raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
# 保存当前的注意力处理器以备后用
self.original_attn_processors = self.attn_processors
# 遍历当前对象的所有模块
for module in self.modules():
# 检查模块是否为 Attention 类型
if isinstance(module, Attention):
# 调用模块的方法,启用融合投影
module.fuse_projections(fuse=True)
# 设置注意力处理器为 FusedAttnProcessor2_0 的实例
self.set_attn_processor(FusedAttnProcessor2_0())
# 从 UNet2DConditionModel 复制的方法,用于禁用融合的 QKV 投影
def unfuse_qkv_projections(self):
# 提供方法的文档字符串,描述其功能和警告信息
"""Disables the fused QKV projection if enabled.
<Tip warning={true}>
This API is 🧪 experimental.
</Tip>
"""
# 检查原始注意力处理器是否不为 None
if self.original_attn_processors is not None:
# 如果不为 None,恢复原始的注意力处理器
self.set_attn_processor(self.original_attn_processors)
# 定义前向传播方法,接受多个输入参数
def forward(
self,
sample: torch.Tensor,
timestep: Union[torch.Tensor, float, int],
fps: torch.Tensor,
image_latents: torch.Tensor,
image_embeddings: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
timestep_cond: Optional[torch.Tensor] = None,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
return_dict: bool = True,
.\diffusers\models\unets\unet_kandinsky3.py
# 版权声明,指明该文件属于 HuggingFace 团队,所有权利保留
#
# 根据 Apache License 2.0 版(“许可证”)授权;
# 除非遵循许可证,否则不得使用此文件。
# 可以在以下网址获取许可证的副本:
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# 除非适用法律要求或书面同意,否则根据许可证分发的软件
# 是在“原样”基础上分发的,不附带任何形式的保证或条件。
# 有关特定语言的许可条款和条件,请参见许可证。
from dataclasses import dataclass # 从 dataclasses 模块导入 dataclass 装饰器
from typing import Dict, Tuple, Union # 导入用于类型提示的字典、元组和联合类型
import torch # 导入 PyTorch 库
import torch.utils.checkpoint # 导入 PyTorch 的检查点工具
from torch import nn # 从 PyTorch 导入神经网络模块
from ...configuration_utils import ConfigMixin, register_to_config # 从配置工具导入混合类和注册函数
from ...utils import BaseOutput, logging # 从工具模块导入基础输出类和日志功能
from ..attention_processor import Attention, AttentionProcessor, AttnProcessor # 导入注意力处理器相关类
from ..embeddings import TimestepEmbedding, Timesteps # 导入时间步嵌入相关类
from ..modeling_utils import ModelMixin # 导入模型混合类
logger = logging.get_logger(__name__) # 创建一个记录器,用于当前模块的日志记录
@dataclass # 将该类标记为数据类,以简化初始化和表示
class Kandinsky3UNetOutput(BaseOutput): # 定义 Kandinsky3UNetOutput 类,继承自 BaseOutput
sample: torch.Tensor = None # 定义输出样本,默认为 None
class Kandinsky3EncoderProj(nn.Module): # 定义 Kandinsky3EncoderProj 类,继承自 nn.Module
def __init__(self, encoder_hid_dim, cross_attention_dim): # 初始化方法,接收隐藏维度和交叉注意力维度
super().__init__() # 调用父类的初始化方法
self.projection_linear = nn.Linear(encoder_hid_dim, cross_attention_dim, bias=False) # 定义线性投影层,不使用偏置
self.projection_norm = nn.LayerNorm(cross_attention_dim) # 定义层归一化层
def forward(self, x): # 定义前向传播方法
x = self.projection_linear(x) # 通过线性层处理输入
x = self.projection_norm(x) # 通过层归一化处理输出
return x # 返回处理后的结果
class Kandinsky3UNet(ModelMixin, ConfigMixin): # 定义 Kandinsky3UNet 类,继承自 ModelMixin 和 ConfigMixin
@register_to_config # 将该方法注册到配置中
def __init__( # 初始化方法
self,
in_channels: int = 4, # 输入通道数,默认值为 4
time_embedding_dim: int = 1536, # 时间嵌入维度,默认值为 1536
groups: int = 32, # 组数,默认值为 32
attention_head_dim: int = 64, # 注意力头维度,默认值为 64
layers_per_block: Union[int, Tuple[int]] = 3, # 每个块的层数,默认值为 3,可以是整数或元组
block_out_channels: Tuple[int] = (384, 768, 1536, 3072), # 块输出通道,默认为指定元组
cross_attention_dim: Union[int, Tuple[int]] = 4096, # 交叉注意力维度,默认值为 4096
encoder_hid_dim: int = 4096, # 编码器隐藏维度,默认值为 4096
@property # 定义一个属性
def attn_processors(self) -> Dict[str, AttentionProcessor]: # 返回注意力处理器字典
r""" # 文档字符串,描述该方法的功能
Returns:
`dict` of attention processors: A dictionary containing all attention processors used in the model with
indexed by its weight name.
"""
# 设置一个空字典以递归存储处理器
processors = {}
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): # 定义递归函数添加处理器
if hasattr(module, "set_processor"): # 检查模块是否具有 set_processor 属性
processors[f"{name}.processor"] = module.processor # 将处理器添加到字典中
for sub_name, child in module.named_children(): # 遍历模块的所有子模块
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) # 递归调用自身
return processors # 返回更新后的处理器字典
for name, module in self.named_children(): # 遍历当前类的所有子模块
fn_recursive_add_processors(name, module, processors) # 调用递归函数
return processors # 返回包含所有处理器的字典
# 定义设置注意力处理器的方法,参数为处理器,可以是 AttentionProcessor 类或其字典形式
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
r"""
设置用于计算注意力的处理器。
参数:
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
实例化的处理器类或处理器类的字典,将作为 **所有** `Attention` 层的处理器。
如果 `processor` 是一个字典,键需要定义相应交叉注意力处理器的路径。这在设置可训练注意力处理器时强烈推荐。
"""
# 获取当前注意力处理器的数量
count = len(self.attn_processors.keys())
# 如果传入的是字典且其长度与注意力层的数量不匹配,则抛出错误
if isinstance(processor, dict) and len(processor) != count:
raise ValueError(
f"传入了处理器字典,但处理器的数量 {len(processor)} 与"
f" 注意力层的数量 {count} 不匹配。请确保传入 {count} 个处理器类。"
)
# 定义递归设置注意力处理器的方法
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
# 如果模块有设置处理器的方法
if hasattr(module, "set_processor"):
# 如果处理器不是字典,则直接设置
if not isinstance(processor, dict):
module.set_processor(processor)
else:
# 从字典中获取对应的处理器并设置
module.set_processor(processor.pop(f"{name}.processor"))
# 遍历模块的所有子模块
for sub_name, child in module.named_children():
# 递归调用处理子模块
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
# 遍历当前对象的所有子模块
for name, module in self.named_children():
# 递归设置每个子模块的处理器
fn_recursive_attn_processor(name, module, processor)
# 定义设置默认注意力处理器的方法
def set_default_attn_processor(self):
"""
禁用自定义注意力处理器,并设置默认的注意力实现。
"""
# 调用设置注意力处理器的方法,使用默认的 AttnProcessor 实例
self.set_attn_processor(AttnProcessor())
# 定义设置梯度检查点的方法
def _set_gradient_checkpointing(self, module, value=False):
# 如果模块有梯度检查点的属性
if hasattr(module, "gradient_checkpointing"):
# 设置该属性为指定的值
module.gradient_checkpointing = value
# 定义前向传播函数,接收样本、时间步以及可选的编码器隐藏状态和注意力掩码
def forward(self, sample, timestep, encoder_hidden_states=None, encoder_attention_mask=None, return_dict=True):
# 如果存在编码器注意力掩码,则进行调整以适应后续计算
if encoder_attention_mask is not None:
encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
# 增加一个维度,以便后续处理
encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
# 检查时间步是否为张量类型
if not torch.is_tensor(timestep):
# 根据时间步类型确定数据类型
dtype = torch.float32 if isinstance(timestep, float) else torch.int32
# 将时间步转换为张量并指定设备
timestep = torch.tensor([timestep], dtype=dtype, device=sample.device)
# 如果时间步为标量,则扩展为一维张量
elif len(timestep.shape) == 0:
timestep = timestep[None].to(sample.device)
# 扩展时间步到与批量维度兼容的形状
timestep = timestep.expand(sample.shape[0])
# 通过时间投影获取时间嵌入输入并转换为样本的数据类型
time_embed_input = self.time_proj(timestep).to(sample.dtype)
# 获取时间嵌入
time_embed = self.time_embedding(time_embed_input)
# 对编码器隐藏状态进行线性变换
encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
# 如果存在编码器隐藏状态,则将时间嵌入与隐藏状态结合
if encoder_hidden_states is not None:
time_embed = self.add_time_condition(time_embed, encoder_hidden_states, encoder_attention_mask)
# 初始化隐藏状态列表
hidden_states = []
# 对输入样本进行初步卷积处理
sample = self.conv_in(sample)
# 遍历下采样块
for level, down_sample in enumerate(self.down_blocks):
# 通过下采样块处理样本
sample = down_sample(sample, time_embed, encoder_hidden_states, encoder_attention_mask)
# 如果不是最后一个层级,记录当前样本状态
if level != self.num_levels - 1:
hidden_states.append(sample)
# 遍历上采样块
for level, up_sample in enumerate(self.up_blocks):
# 如果不是第一个层级,则拼接当前样本与之前的隐藏状态
if level != 0:
sample = torch.cat([sample, hidden_states.pop()], dim=1)
# 通过上采样块处理样本
sample = up_sample(sample, time_embed, encoder_hidden_states, encoder_attention_mask)
# 进行输出卷积规范化
sample = self.conv_norm_out(sample)
# 进行输出激活
sample = self.conv_act_out(sample)
# 进行最终输出卷积
sample = self.conv_out(sample)
# 根据返回标志返回相应的结果
if not return_dict:
return (sample,)
# 返回结果对象
return Kandinsky3UNetOutput(sample=sample)
# 定义 Kandinsky3UpSampleBlock 类,继承自 nn.Module
class Kandinsky3UpSampleBlock(nn.Module):
# 初始化方法,设置各参数
def __init__(
self,
in_channels, # 输入通道数
cat_dim, # 拼接维度
out_channels, # 输出通道数
time_embed_dim, # 时间嵌入维度
context_dim=None, # 上下文维度,可选
num_blocks=3, # 块的数量
groups=32, # 分组数
head_dim=64, # 头维度
expansion_ratio=4, # 扩展比例
compression_ratio=2, # 压缩比例
up_sample=True, # 是否上采样
self_attention=True, # 是否使用自注意力
):
# 调用父类初始化方法
super().__init__()
# 设置上采样分辨率
up_resolutions = [[None, True if up_sample else None, None, None]] + [[None] * 4] * (num_blocks - 1)
# 设置隐藏通道数
hidden_channels = (
[(in_channels + cat_dim, in_channels)] # 第一层的通道
+ [(in_channels, in_channels)] * (num_blocks - 2) # 中间层的通道
+ [(in_channels, out_channels)] # 最后一层的通道
)
attentions = [] # 用于存储注意力块
resnets_in = [] # 用于存储输入 ResNet 块
resnets_out = [] # 用于存储输出 ResNet 块
# 设置自注意力和上下文维度
self.self_attention = self_attention
self.context_dim = context_dim
# 如果使用自注意力,添加注意力块
if self_attention:
attentions.append(
Kandinsky3AttentionBlock(out_channels, time_embed_dim, None, groups, head_dim, expansion_ratio)
)
else:
attentions.append(nn.Identity()) # 否则添加身份映射
# 遍历隐藏通道和上采样分辨率
for (in_channel, out_channel), up_resolution in zip(hidden_channels, up_resolutions):
# 添加输入 ResNet 块
resnets_in.append(
Kandinsky3ResNetBlock(in_channel, in_channel, time_embed_dim, groups, compression_ratio, up_resolution)
)
# 如果上下文维度不为 None,添加注意力块
if context_dim is not None:
attentions.append(
Kandinsky3AttentionBlock(
in_channel, time_embed_dim, context_dim, groups, head_dim, expansion_ratio
)
)
else:
attentions.append(nn.Identity()) # 否则添加身份映射
# 添加输出 ResNet 块
resnets_out.append(
Kandinsky3ResNetBlock(in_channel, out_channel, time_embed_dim, groups, compression_ratio)
)
# 将注意力块和 ResNet 块转换为模块列表
self.attentions = nn.ModuleList(attentions)
self.resnets_in = nn.ModuleList(resnets_in)
self.resnets_out = nn.ModuleList(resnets_out)
# 前向传播方法
def forward(self, x, time_embed, context=None, context_mask=None, image_mask=None):
# 遍历注意力块和 ResNet 块进行前向计算
for attention, resnet_in, resnet_out in zip(self.attentions[1:], self.resnets_in, self.resnets_out):
x = resnet_in(x, time_embed) # 输入经过 ResNet 块
if self.context_dim is not None: # 如果上下文维度存在
x = attention(x, time_embed, context, context_mask, image_mask) # 应用注意力块
x = resnet_out(x, time_embed) # 输出经过 ResNet 块
# 如果使用自注意力,应用首个注意力块
if self.self_attention:
x = self.attentions[0](x, time_embed, image_mask=image_mask)
return x # 返回处理后的结果
# 定义 Kandinsky3DownSampleBlock 类,继承自 nn.Module
class Kandinsky3DownSampleBlock(nn.Module):
# 初始化方法,设置各参数
def __init__(
self,
in_channels, # 输入通道数
out_channels, # 输出通道数
time_embed_dim, # 时间嵌入维度
context_dim=None, # 上下文维度,可选
num_blocks=3, # 块的数量
groups=32, # 分组数
head_dim=64, # 头维度
expansion_ratio=4, # 扩展比例
compression_ratio=2, # 压缩比例
down_sample=True, # 是否下采样
self_attention=True, # 是否使用自注意力
):
# 调用父类的初始化方法
super().__init__()
# 初始化注意力模块列表
attentions = []
# 初始化输入残差块列表
resnets_in = []
# 初始化输出残差块列表
resnets_out = []
# 保存自注意力标志
self.self_attention = self_attention
# 保存上下文维度
self.context_dim = context_dim
# 如果启用自注意力
if self_attention:
# 添加 Kandinsky3AttentionBlock 到注意力列表
attentions.append(
Kandinsky3AttentionBlock(in_channels, time_embed_dim, None, groups, head_dim, expansion_ratio)
)
else:
# 否则添加身份层(不改变输入)
attentions.append(nn.Identity())
# 生成上采样分辨率列表
up_resolutions = [[None] * 4] * (num_blocks - 1) + [[None, None, False if down_sample else None, None]]
# 生成隐藏通道的元组列表
hidden_channels = [(in_channels, out_channels)] + [(out_channels, out_channels)] * (num_blocks - 1)
# 遍历隐藏通道和上采样分辨率
for (in_channel, out_channel), up_resolution in zip(hidden_channels, up_resolutions):
# 添加输入残差块到列表
resnets_in.append(
Kandinsky3ResNetBlock(in_channel, out_channel, time_embed_dim, groups, compression_ratio)
)
# 如果上下文维度不为 None
if context_dim is not None:
# 添加 Kandinsky3AttentionBlock 到注意力列表
attentions.append(
Kandinsky3AttentionBlock(
out_channel, time_embed_dim, context_dim, groups, head_dim, expansion_ratio
)
)
else:
# 否则添加身份层(不改变输入)
attentions.append(nn.Identity())
# 添加输出残差块到列表
resnets_out.append(
Kandinsky3ResNetBlock(
out_channel, out_channel, time_embed_dim, groups, compression_ratio, up_resolution
)
)
# 将注意力模块列表转换为 nn.ModuleList 以便管理
self.attentions = nn.ModuleList(attentions)
# 将输入残差块列表转换为 nn.ModuleList 以便管理
self.resnets_in = nn.ModuleList(resnets_in)
# 将输出残差块列表转换为 nn.ModuleList 以便管理
self.resnets_out = nn.ModuleList(resnets_out)
# 定义前向传播方法
def forward(self, x, time_embed, context=None, context_mask=None, image_mask=None):
# 如果启用自注意力
if self.self_attention:
# 使用第一个注意力模块处理输入
x = self.attentions[0](x, time_embed, image_mask=image_mask)
# 遍历剩余的注意力模块、输入和输出残差块
for attention, resnet_in, resnet_out in zip(self.attentions[1:], self.resnets_in, self.resnets_out):
# 通过输入残差块处理输入
x = resnet_in(x, time_embed)
# 如果上下文维度不为 None
if self.context_dim is not None:
# 使用当前注意力模块处理输入
x = attention(x, time_embed, context, context_mask, image_mask)
# 通过输出残差块处理输入
x = resnet_out(x, time_embed)
# 返回处理后的输出
return x
# 定义 Kandinsky3ConditionalGroupNorm 类,继承自 nn.Module
class Kandinsky3ConditionalGroupNorm(nn.Module):
# 初始化方法,设置分组数、标准化形状和上下文维度
def __init__(self, groups, normalized_shape, context_dim):
# 调用父类构造函数
super().__init__()
# 创建分组归一化层,不使用仿射变换
self.norm = nn.GroupNorm(groups, normalized_shape, affine=False)
# 定义上下文多层感知机,包含 SiLU 激活和线性层
self.context_mlp = nn.Sequential(nn.SiLU(), nn.Linear(context_dim, 2 * normalized_shape))
# 将线性层的权重初始化为零
self.context_mlp[1].weight.data.zero_()
# 将线性层的偏置初始化为零
self.context_mlp[1].bias.data.zero_()
# 前向传播方法,接收输入和上下文
def forward(self, x, context):
# 通过上下文多层感知机处理上下文
context = self.context_mlp(context)
# 为了匹配输入的维度,逐层扩展上下文
for _ in range(len(x.shape[2:])):
context = context.unsqueeze(-1)
# 将上下文分割为缩放和偏移量
scale, shift = context.chunk(2, dim=1)
# 应用归一化并进行缩放和偏移
x = self.norm(x) * (scale + 1.0) + shift
# 返回处理后的输入
return x
# 定义 Kandinsky3Block 类,继承自 nn.Module
class Kandinsky3Block(nn.Module):
# 初始化方法,设置输入通道、输出通道、时间嵌入维度等参数
def __init__(self, in_channels, out_channels, time_embed_dim, kernel_size=3, norm_groups=32, up_resolution=None):
# 调用父类构造函数
super().__init__()
# 创建条件分组归一化层
self.group_norm = Kandinsky3ConditionalGroupNorm(norm_groups, in_channels, time_embed_dim)
# 定义 SiLU 激活函数
self.activation = nn.SiLU()
# 如果需要上采样,使用转置卷积进行上采样
if up_resolution is not None and up_resolution:
self.up_sample = nn.ConvTranspose2d(in_channels, in_channels, kernel_size=2, stride=2)
else:
# 否则使用恒等映射
self.up_sample = nn.Identity()
# 根据卷积核大小确定填充
padding = int(kernel_size > 1)
# 定义卷积投影层
self.projection = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=padding)
# 如果不需要上采样,定义下采样卷积层
if up_resolution is not None and not up_resolution:
self.down_sample = nn.Conv2d(out_channels, out_channels, kernel_size=2, stride=2)
else:
# 否则使用恒等映射
self.down_sample = nn.Identity()
# 前向传播方法,接收输入和时间嵌入
def forward(self, x, time_embed):
# 通过条件分组归一化处理输入
x = self.group_norm(x, time_embed)
# 应用激活函数
x = self.activation(x)
# 进行上采样
x = self.up_sample(x)
# 通过卷积投影层处理输入
x = self.projection(x)
# 进行下采样
x = self.down_sample(x)
# 返回处理后的输出
return x
# 定义 Kandinsky3ResNetBlock 类,继承自 nn.Module
class Kandinsky3ResNetBlock(nn.Module):
# 初始化方法,设置输入通道、输出通道、时间嵌入维度等参数
def __init__(
self, in_channels, out_channels, time_embed_dim, norm_groups=32, compression_ratio=2, up_resolutions=4 * [None]
# 初始化父类
):
super().__init__()
# 定义卷积核的大小
kernel_sizes = [1, 3, 3, 1]
# 计算隐藏通道数
hidden_channel = max(in_channels, out_channels) // compression_ratio
# 构建隐藏通道的元组列表
hidden_channels = (
[(in_channels, hidden_channel)] + [(hidden_channel, hidden_channel)] * 2 + [(hidden_channel, out_channels)]
)
# 创建包含多个 Kandinsky3Block 的模块列表
self.resnet_blocks = nn.ModuleList(
[
Kandinsky3Block(in_channel, out_channel, time_embed_dim, kernel_size, norm_groups, up_resolution)
# 将隐藏通道、卷积核大小和上采样分辨率组合在一起
for (in_channel, out_channel), kernel_size, up_resolution in zip(
hidden_channels, kernel_sizes, up_resolutions
)
]
)
# 定义上采样的快捷连接
self.shortcut_up_sample = (
nn.ConvTranspose2d(in_channels, in_channels, kernel_size=2, stride=2)
# 如果存在上采样分辨率,则使用反卷积;否则使用恒等映射
if True in up_resolutions
else nn.Identity()
)
# 定义通道数不同时的投影连接
self.shortcut_projection = (
nn.Conv2d(in_channels, out_channels, kernel_size=1) if in_channels != out_channels else nn.Identity()
)
# 定义下采样的快捷连接
self.shortcut_down_sample = (
nn.Conv2d(out_channels, out_channels, kernel_size=2, stride=2)
# 如果存在下采样分辨率,则使用卷积;否则使用恒等映射
if False in up_resolutions
else nn.Identity()
)
# 前向传播方法
def forward(self, x, time_embed):
# 初始化输出为输入
out = x
# 依次通过每个 ResNet 块
for resnet_block in self.resnet_blocks:
out = resnet_block(out, time_embed)
# 上采样输入
x = self.shortcut_up_sample(x)
# 投影输入到输出通道
x = self.shortcut_projection(x)
# 下采样输入
x = self.shortcut_down_sample(x)
# 将输出与处理后的输入相加
x = x + out
# 返回最终输出
return x
# 定义 Kandinsky3AttentionPooling 类,继承自 nn.Module
class Kandinsky3AttentionPooling(nn.Module):
# 初始化方法,接受通道数、上下文维度和头维度
def __init__(self, num_channels, context_dim, head_dim=64):
# 调用父类构造函数
super().__init__()
# 创建注意力机制对象,指定输入和输出维度及其他参数
self.attention = Attention(
context_dim,
context_dim,
dim_head=head_dim,
out_dim=num_channels,
out_bias=False,
)
# 前向传播方法
def forward(self, x, context, context_mask=None):
# 将上下文掩码转换为与上下文相同的数据类型
context_mask = context_mask.to(dtype=context.dtype)
# 使用注意力机制计算上下文与其平均值的加权和
context = self.attention(context.mean(dim=1, keepdim=True), context, context_mask)
# 返回输入与上下文的和
return x + context.squeeze(1)
# 定义 Kandinsky3AttentionBlock 类,继承自 nn.Module
class Kandinsky3AttentionBlock(nn.Module):
# 初始化方法,接受多种参数
def __init__(self, num_channels, time_embed_dim, context_dim=None, norm_groups=32, head_dim=64, expansion_ratio=4):
# 调用父类构造函数
super().__init__()
# 创建条件组归一化对象,用于输入规范化
self.in_norm = Kandinsky3ConditionalGroupNorm(norm_groups, num_channels, time_embed_dim)
# 创建注意力机制对象,指定输入和输出维度及其他参数
self.attention = Attention(
num_channels,
context_dim or num_channels,
dim_head=head_dim,
out_dim=num_channels,
out_bias=False,
)
# 计算隐藏通道数,作为扩展比和通道数的乘积
hidden_channels = expansion_ratio * num_channels
# 创建条件组归一化对象,用于输出规范化
self.out_norm = Kandinsky3ConditionalGroupNorm(norm_groups, num_channels, time_embed_dim)
# 定义前馈网络,包含两个卷积层和激活函数
self.feed_forward = nn.Sequential(
nn.Conv2d(num_channels, hidden_channels, kernel_size=1, bias=False),
nn.SiLU(),
nn.Conv2d(hidden_channels, num_channels, kernel_size=1, bias=False),
)
# 前向传播方法
def forward(self, x, time_embed, context=None, context_mask=None, image_mask=None):
# 获取输入的高度和宽度
height, width = x.shape[-2:]
# 对输入进行归一化处理
out = self.in_norm(x, time_embed)
# 将输出重塑为适合注意力机制的形状
out = out.reshape(x.shape[0], -1, height * width).permute(0, 2, 1)
# 如果没有上下文,则使用当前的输出作为上下文
context = context if context is not None else out
# 如果存在上下文掩码,转换为与上下文相同的数据类型
if context_mask is not None:
context_mask = context_mask.to(dtype=context.dtype)
# 使用注意力机制处理输出和上下文
out = self.attention(out, context, context_mask)
# 重塑输出为原始输入形状
out = out.permute(0, 2, 1).unsqueeze(-1).reshape(out.shape[0], -1, height, width)
# 将处理后的输出与原输入相加
x = x + out
# 对相加后的结果进行输出归一化
out = self.out_norm(x, time_embed)
# 通过前馈网络处理归一化输出
out = self.feed_forward(out)
# 将处理后的输出与相加后的输入相加
x = x + out
# 返回最终输出
return x
.\diffusers\models\unets\unet_motion_model.py
# 版权声明,表明该文件的所有权及相关使用条款
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# 根据 Apache License, Version 2.0 (“许可证”) 授权;
# 除非遵守许可证,否则您不得使用此文件。
# 您可以在以下网址获取许可证副本:
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# 除非适用法律或书面协议另有规定,否则根据许可证分发的软件
# 是在“按原样”基础上分发的,不提供任何形式的保证或条件,
# 无论是明示还是暗示。
# 有关许可证所管辖的权限和限制,请参见许可证。
#
# 导入所需的库和模块
from dataclasses import dataclass # 导入数据类装饰器
from typing import Any, Dict, Optional, Tuple, Union # 导入类型提示相关的类型
import torch # 导入 PyTorch 库
import torch.nn as nn # 导入 PyTorch 的神经网络模块
import torch.nn.functional as F # 导入 PyTorch 的功能性神经网络模块
import torch.utils.checkpoint # 导入 PyTorch 的检查点功能
# 导入自定义的配置和加载工具
from ...configuration_utils import ConfigMixin, FrozenDict, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin, UNet2DConditionLoadersMixin
from ...utils import BaseOutput, deprecate, is_torch_version, logging # 导入常用的工具函数
from ...utils.torch_utils import apply_freeu # 导入应用 FreeU 的工具函数
from ..attention import BasicTransformerBlock # 导入基础变换器模块
from ..attention_processor import ( # 导入注意力处理器相关的类
ADDED_KV_ATTENTION_PROCESSORS,
CROSS_ATTENTION_PROCESSORS,
Attention,
AttentionProcessor,
AttnAddedKVProcessor,
AttnProcessor,
AttnProcessor2_0,
FusedAttnProcessor2_0,
IPAdapterAttnProcessor,
IPAdapterAttnProcessor2_0,
)
from ..embeddings import TimestepEmbedding, Timesteps # 导入时间步嵌入相关的类
from ..modeling_utils import ModelMixin # 导入模型混合工具
from ..resnet import Downsample2D, ResnetBlock2D, Upsample2D # 导入 ResNet 相关的模块
from ..transformers.dual_transformer_2d import DualTransformer2DModel # 导入双重变换器模型
from ..transformers.transformer_2d import Transformer2DModel # 导入 2D 变换器模型
from .unet_2d_blocks import UNetMidBlock2DCrossAttn # 导入 U-Net 中间块
from .unet_2d_condition import UNet2DConditionModel # 导入条件 U-Net 模型
logger = logging.get_logger(__name__) # 获取当前模块的日志记录器,便于调试和日志输出
@dataclass
class UNetMotionOutput(BaseOutput): # 定义 UNetMotionOutput 数据类,继承自 BaseOutput
"""
[`UNetMotionOutput`] 的输出。
参数:
sample (`torch.Tensor` 的形状为 `(batch_size, num_channels, num_frames, height, width)`):
基于 `encoder_hidden_states` 输入的隐藏状态输出。模型最后一层的输出。
"""
sample: torch.Tensor # 定义 sample 属性,类型为 torch.Tensor
class AnimateDiffTransformer3D(nn.Module): # 定义 AnimateDiffTransformer3D 类,继承自 nn.Module
"""
一个用于视频类数据的变换器模型。
# 参数说明部分,描述初始化函数中每个参数的用途
Parameters:
# 多头注意力机制中头的数量,默认为16
num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
# 每个头中的通道数,默认为88
attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
# 输入和输出的通道数,如果输入是**连续**,则需要指定
in_channels (`int`, *optional*):
The number of channels in the input and output (specify if the input is **continuous**).
# Transformer块的层数,默认为1
num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
# dropout概率,默认为0.0
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
# 使用的`encoder_hidden_states`维度数
cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
# 配置`TransformerBlock`的注意力是否包含偏置参数
attention_bias (`bool`, *optional*):
Configure if the `TransformerBlock` attention should contain a bias parameter.
# 潜在图像的宽度,如果输入是**离散**,则需要指定
sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**).
# 该值在训练期间固定,用于学习位置嵌入的数量
This is fixed during training since it is used to learn a number of position embeddings.
# 前馈中的激活函数,默认为"geglu"
activation_fn (`str`, *optional*, defaults to `"geglu"`):
Activation function to use in feed-forward. See `diffusers.models.activations.get_activation` for supported
activation functions.
# 配置`TransformerBlock`是否使用可学习的逐元素仿射参数进行归一化
norm_elementwise_affine (`bool`, *optional*):
Configure if the `TransformerBlock` should use learnable elementwise affine parameters for normalization.
# 配置每个`TransformerBlock`是否包含两个自注意力层
double_self_attention (`bool`, *optional*):
Configure if each `TransformerBlock` should contain two self-attention layers.
# 应用到序列输入的位置信息嵌入的类型
positional_embeddings: (`str`, *optional*):
The type of positional embeddings to apply to the sequence input before passing use.
# 应用位置嵌入的最大序列长度
num_positional_embeddings: (`int`, *optional*):
The maximum length of the sequence over which to apply positional embeddings.
"""
# 初始化方法定义
def __init__(
# 多头注意力机制中头的数量,默认为16
self,
num_attention_heads: int = 16,
# 每个头中的通道数,默认为88
attention_head_dim: int = 88,
# 输入通道数,可选
in_channels: Optional[int] = None,
# 输出通道数,可选
out_channels: Optional[int] = None,
# Transformer块的层数,默认为1
num_layers: int = 1,
# dropout概率,默认为0.0
dropout: float = 0.0,
# 归一化分组数,默认为32
norm_num_groups: int = 32,
# 使用的`encoder_hidden_states`维度数,可选
cross_attention_dim: Optional[int] = None,
# 注意力是否包含偏置参数,默认为False
attention_bias: bool = False,
# 潜在图像的宽度,可选
sample_size: Optional[int] = None,
# 前馈中的激活函数,默认为"geglu"
activation_fn: str = "geglu",
# 归一化是否使用可学习的逐元素仿射参数,默认为True
norm_elementwise_affine: bool = True,
# 每个`TransformerBlock`是否包含两个自注意力层,默认为True
double_self_attention: bool = True,
# 位置信息嵌入的类型,可选
positional_embeddings: Optional[str] = None,
# 应用位置嵌入的最大序列长度,可选
num_positional_embeddings: Optional[int] = None,
):
# 调用父类的构造函数以初始化父类的属性
super().__init__()
# 设置注意力头的数量
self.num_attention_heads = num_attention_heads
# 设置每个注意力头的维度
self.attention_head_dim = attention_head_dim
# 计算内部维度,等于注意力头数量与每个注意力头维度的乘积
inner_dim = num_attention_heads * attention_head_dim
# 设置输入通道数
self.in_channels = in_channels
# 定义归一化层,使用组归一化,允许可学习的偏移
self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
# 定义输入线性变换层,将输入通道映射到内部维度
self.proj_in = nn.Linear(in_channels, inner_dim)
# 3. 定义变换器块
self.transformer_blocks = nn.ModuleList(
[
# 创建指定数量的基本变换器块
BasicTransformerBlock(
inner_dim,
num_attention_heads,
attention_head_dim,
dropout=dropout,
cross_attention_dim=cross_attention_dim,
activation_fn=activation_fn,
attention_bias=attention_bias,
double_self_attention=double_self_attention,
norm_elementwise_affine=norm_elementwise_affine,
positional_embeddings=positional_embeddings,
num_positional_embeddings=num_positional_embeddings,
)
# 遍历创建 num_layers 个基本变换器块
for _ in range(num_layers)
]
)
# 定义输出线性变换层,将内部维度映射回输入通道数
self.proj_out = nn.Linear(inner_dim, in_channels)
def forward(
# 定义前向传播方法的输入参数
self,
hidden_states: torch.Tensor, # 输入的隐藏状态张量
encoder_hidden_states: Optional[torch.LongTensor] = None, # 编码器的隐藏状态,默认为 None
timestep: Optional[torch.LongTensor] = None, # 时间步,默认为 None
class_labels: Optional[torch.LongTensor] = None, # 类标签,默认为 None
num_frames: int = 1, # 帧数,默认值为 1
cross_attention_kwargs: Optional[Dict[str, Any]] = None, # 跨注意力参数,默认为 None
# 该方法用于 [`AnimateDiffTransformer3D`] 的前向传播
) -> torch.Tensor:
"""
方法参数说明:
hidden_states (`torch.LongTensor`): 输入的隐状态,形状为 `(batch size, num latent pixels)` 或 `(batch size, channel, height, width)`
encoder_hidden_states ( `torch.LongTensor`, *可选*):
交叉注意力层的条件嵌入。如果未提供,交叉注意力将默认使用自注意力。
timestep ( `torch.LongTensor`, *可选*):
用于指示去噪步骤的时间戳。
class_labels ( `torch.LongTensor`, *可选*):
用于指示类别标签的条件嵌入。
num_frames (`int`, *可选*, 默认为 1):
每个批次处理的帧数,用于重新形状隐状态。
cross_attention_kwargs (`dict`, *可选*):
可选的关键字字典,传递给 `AttentionProcessor`。
返回值:
torch.Tensor:
输出张量。
"""
# 1. 输入
# 获取输入隐状态的形状信息
batch_frames, channel, height, width = hidden_states.shape
# 计算批次大小
batch_size = batch_frames // num_frames
# 将隐状态保留用于残差连接
residual = hidden_states
# 调整隐状态的形状以适应批次和帧数
hidden_states = hidden_states[None, :].reshape(batch_size, num_frames, channel, height, width)
# 调整维度顺序以便后续处理
hidden_states = hidden_states.permute(0, 2, 1, 3, 4)
# 对隐状态进行规范化
hidden_states = self.norm(hidden_states)
# 再次调整维度顺序并重塑为适当的形状
hidden_states = hidden_states.permute(0, 3, 4, 2, 1).reshape(batch_size * height * width, num_frames, channel)
# 输入层投影
hidden_states = self.proj_in(hidden_states)
# 2. 处理块
# 遍历每个变换块以处理隐状态
for block in self.transformer_blocks:
hidden_states = block(
hidden_states, # 当前的隐状态
encoder_hidden_states=encoder_hidden_states, # 可选的编码器隐状态
timestep=timestep, # 可选的时间戳
cross_attention_kwargs=cross_attention_kwargs, # 可选的交叉注意力参数
class_labels=class_labels, # 可选的类标签
)
# 3. 输出
# 输出层投影
hidden_states = self.proj_out(hidden_states)
# 调整输出张量的形状
hidden_states = (
hidden_states[None, None, :] # 添加维度
.reshape(batch_size, height, width, num_frames, channel) # 重塑为适当形状
.permute(0, 3, 4, 1, 2) # 调整维度顺序
.contiguous() # 确保内存连续性
)
# 最终调整输出的形状
hidden_states = hidden_states.reshape(batch_frames, channel, height, width)
# 将残差添加到输出中以形成最终输出
output = hidden_states + residual
# 返回最终的输出张量
return output
# 定义一个名为 DownBlockMotion 的类,继承自 nn.Module
class DownBlockMotion(nn.Module):
# 初始化方法,定义多个参数,包括输入输出通道、dropout 率等
def __init__(
self,
in_channels: int, # 输入通道数量
out_channels: int, # 输出通道数量
temb_channels: int, # 时间嵌入通道数量
dropout: float = 0.0, # dropout 率,默认为 0
num_layers: int = 1, # 网络层数,默认为 1
resnet_eps: float = 1e-6, # ResNet 的 epsilon 参数
resnet_time_scale_shift: str = "default", # ResNet 时间尺度偏移
resnet_act_fn: str = "swish", # ResNet 激活函数,默认为 swish
resnet_groups: int = 32, # ResNet 组数,默认为 32
resnet_pre_norm: bool = True, # ResNet 是否使用预归一化
output_scale_factor: float = 1.0, # 输出缩放因子
add_downsample: bool = True, # 是否添加下采样层
downsample_padding: int = 1, # 下采样时的填充
temporal_num_attention_heads: Union[int, Tuple[int]] = 1, # 时间注意力头数
temporal_cross_attention_dim: Optional[int] = None, # 时间交叉注意力维度
temporal_max_seq_length: int = 32, # 最大序列长度
temporal_transformer_layers_per_block: Union[int, Tuple[int]] = 1, # 每个块的变换器层数
temporal_double_self_attention: bool = True, # 是否双重自注意力
):
# 前向传播方法,接收隐藏状态和时间嵌入等参数
def forward(
self,
hidden_states: torch.Tensor, # 输入的隐藏状态张量
temb: Optional[torch.Tensor] = None, # 可选的时间嵌入张量
num_frames: int = 1, # 帧数,默认为 1
*args, # 接受任意位置参数
**kwargs, # 接受任意关键字参数
) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]: # 返回隐藏状态和输出状态的张量或元组
# 检查位置参数或关键字参数中的 scale 是否被传递
if len(args) > 0 or kwargs.get("scale", None) is not None:
# 定义弃用信息,提示用户 scale 参数将被忽略
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
# 调用弃用函数,发出警告
deprecate("scale", "1.0.0", deprecation_message)
# 初始化输出状态为一个空元组
output_states = ()
# 将 ResNet 和运动模块进行配对
blocks = zip(self.resnets, self.motion_modules)
# 遍历每对 ResNet 和运动模块
for resnet, motion_module in blocks:
# 如果处于训练模式且启用了梯度检查点
if self.training and self.gradient_checkpointing:
# 定义一个自定义前向传播函数
def create_custom_forward(module):
def custom_forward(*inputs): # 自定义前向函数,接受任意输入
return module(*inputs) # 返回模块的输出
return custom_forward # 返回自定义前向函数
# 如果 PyTorch 版本大于等于 1.11.0
if is_torch_version(">=", "1.11.0"):
# 使用检查点机制来节省内存
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet), # 创建的自定义前向函数
hidden_states, # 输入的隐藏状态
temb, # 输入的时间嵌入
use_reentrant=False, # 不使用重入
)
else:
# 在较早版本中也使用检查点机制
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet), hidden_states, temb
)
else:
# 如果不是训练模式,直接通过 ResNet 处理隐藏状态
hidden_states = resnet(hidden_states, temb)
# 使用运动模块处理当前的隐藏状态
hidden_states = motion_module(hidden_states, num_frames=num_frames)
# 将当前隐藏状态添加到输出状态中
output_states = output_states + (hidden_states,)
# 如果下采样器不为空
if self.downsamplers is not None:
# 遍历每个下采样器
for downsampler in self.downsamplers:
# 通过下采样器处理隐藏状态
hidden_states = downsampler(hidden_states)
# 将下采样后的隐藏状态添加到输出状态中
output_states = output_states + (hidden_states,)
# 返回最终的隐藏状态和输出状态
return hidden_states, output_states
# 初始化方法,用于设置网络的参数
def __init__(
# 输入通道数量
self,
in_channels: int,
# 输出通道数量
out_channels: int,
# 时间嵌入通道数量
temb_channels: int,
# dropout 概率,默认为 0.0
dropout: float = 0.0,
# 网络层数,默认为 1
num_layers: int = 1,
# 每个块中的变换器层数,默认为 1
transformer_layers_per_block: Union[int, Tuple[int]] = 1,
# ResNet 中的 epsilon 值,默认为 1e-6
resnet_eps: float = 1e-6,
# ResNet 时间尺度偏移,默认为 "default"
resnet_time_scale_shift: str = "default",
# ResNet 激活函数,默认为 "swish"
resnet_act_fn: str = "swish",
# ResNet 中的组数,默认为 32
resnet_groups: int = 32,
# 是否在 ResNet 中使用预归一化,默认为 True
resnet_pre_norm: bool = True,
# 注意力头的数量,默认为 1
num_attention_heads: int = 1,
# 交叉注意力维度,默认为 1280
cross_attention_dim: int = 1280,
# 输出缩放因子,默认为 1.0
output_scale_factor: float = 1.0,
# 下采样填充,默认为 1
downsample_padding: int = 1,
# 是否添加下采样层,默认为 True
add_downsample: bool = True,
# 是否使用双交叉注意力,默认为 False
dual_cross_attention: bool = False,
# 是否使用线性投影,默认为 False
use_linear_projection: bool = False,
# 是否仅使用交叉注意力,默认为 False
only_cross_attention: bool = False,
# 是否提升注意力计算精度,默认为 False
upcast_attention: bool = False,
# 注意力类型,默认为 "default"
attention_type: str = "default",
# 时间交叉注意力维度,可选参数
temporal_cross_attention_dim: Optional[int] = None,
# 时间注意力头数量,默认为 8
temporal_num_attention_heads: int = 8,
# 时间序列的最大长度,默认为 32
temporal_max_seq_length: int = 32,
# 时间变换器块中的层数,默认为 1
temporal_transformer_layers_per_block: Union[int, Tuple[int]] = 1,
# 是否使用双重自注意力,默认为 True
temporal_double_self_attention: bool = True,
# 前向传播方法,定义如何通过模型传递数据
def forward(
# 隐藏状态张量,输入到模型中的主要数据
self,
hidden_states: torch.Tensor,
# 可选的时间嵌入张量
temb: Optional[torch.Tensor] = None,
# 可选的编码器隐藏状态
encoder_hidden_states: Optional[torch.Tensor] = None,
# 可选的注意力掩码
attention_mask: Optional[torch.Tensor] = None,
# 每次处理的帧数,默认为 1
num_frames: int = 1,
# 可选的编码器注意力掩码
encoder_attention_mask: Optional[torch.Tensor] = None,
# 可选的交叉注意力参数
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
# 可选的额外残差连接
additional_residuals: Optional[torch.Tensor] = None,
):
# 检查 cross_attention_kwargs 是否不为空
if cross_attention_kwargs is not None:
# 检查 scale 参数是否存在,若存在则发出警告
if cross_attention_kwargs.get("scale", None) is not None:
logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
# 初始化输出状态为空元组
output_states = ()
# 将自残差网络、注意力模块和运动模块组合成一个列表
blocks = list(zip(self.resnets, self.attentions, self.motion_modules))
# 遍历组合后的模块及其索引
for i, (resnet, attn, motion_module) in enumerate(blocks):
# 如果处于训练状态且启用了梯度检查点
if self.training and self.gradient_checkpointing:
# 定义自定义前向传播函数
def create_custom_forward(module, return_dict=None):
# 定义实际的前向传播逻辑
def custom_forward(*inputs):
# 根据 return_dict 的值选择返回方式
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
# 定义检查点参数字典,根据 PyTorch 版本设置
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
# 使用检查点机制计算隐藏状态
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet),
hidden_states,
temb,
**ckpt_kwargs,
)
# 通过注意力模块处理隐藏状态
hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
cross_attention_kwargs=cross_attention_kwargs,
attention_mask=attention_mask,
encoder_attention_mask=encoder_attention_mask,
return_dict=False,
)[0]
else:
# 在非训练模式下直接通过残差网络处理隐藏状态
hidden_states = resnet(hidden_states, temb)
# 通过注意力模块处理隐藏状态
hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
cross_attention_kwargs=cross_attention_kwargs,
attention_mask=attention_mask,
encoder_attention_mask=encoder_attention_mask,
return_dict=False,
)[0]
# 通过运动模块处理隐藏状态
hidden_states = motion_module(
hidden_states,
num_frames=num_frames,
)
# 如果是最后一对模块且有额外残差,则将其应用到隐藏状态
if i == len(blocks) - 1 and additional_residuals is not None:
hidden_states = hidden_states + additional_residuals
# 将当前隐藏状态添加到输出状态中
output_states = output_states + (hidden_states,)
# 如果存在下采样模块,则依次应用它们
if self.downsamplers is not None:
for downsampler in self.downsamplers:
hidden_states = downsampler(hidden_states)
# 将下采样后的隐藏状态添加到输出状态中
output_states = output_states + (hidden_states,)
# 返回最终的隐藏状态和输出状态
return hidden_states, output_states
# 定义一个继承自 nn.Module 的类,用于交叉注意力上采样块
class CrossAttnUpBlockMotion(nn.Module):
# 初始化方法,设置各层的参数
def __init__(
self,
in_channels: int, # 输入通道数
out_channels: int, # 输出通道数
prev_output_channel: int, # 前一层输出的通道数
temb_channels: int, # 时间嵌入通道数
resolution_idx: Optional[int] = None, # 分辨率索引,默认为 None
dropout: float = 0.0, # dropout 概率
num_layers: int = 1, # 层数
transformer_layers_per_block: Union[int, Tuple[int]] = 1, # 每个块的变换器层数
resnet_eps: float = 1e-6, # ResNet 的 epsilon 值
resnet_time_scale_shift: str = "default", # ResNet 时间缩放偏移
resnet_act_fn: str = "swish", # ResNet 激活函数
resnet_groups: int = 32, # ResNet 组数
resnet_pre_norm: bool = True, # 是否在前面进行归一化
num_attention_heads: int = 1, # 注意力头的数量
cross_attention_dim: int = 1280, # 交叉注意力的维度
output_scale_factor: float = 1.0, # 输出缩放因子
add_upsample: bool = True, # 是否添加上采样
dual_cross_attention: bool = False, # 是否使用双重交叉注意力
use_linear_projection: bool = False, # 是否使用线性投影
only_cross_attention: bool = False, # 是否仅使用交叉注意力
upcast_attention: bool = False, # 是否上浮注意力
attention_type: str = "default", # 注意力类型
temporal_cross_attention_dim: Optional[int] = None, # 时间交叉注意力维度,默认为 None
temporal_num_attention_heads: int = 8, # 时间注意力头数量
temporal_max_seq_length: int = 32, # 时间序列的最大长度
temporal_transformer_layers_per_block: Union[int, Tuple[int]] = 1, # 时间块的变换器层数
# 定义前向传播方法
def forward(
self,
hidden_states: torch.Tensor, # 输入的隐藏状态张量
res_hidden_states_tuple: Tuple[torch.Tensor, ...], # 之前隐藏状态的元组
temb: Optional[torch.Tensor] = None, # 可选的时间嵌入张量
encoder_hidden_states: Optional[torch.Tensor] = None, # 可选的编码器隐藏状态
cross_attention_kwargs: Optional[Dict[str, Any]] = None, # 交叉注意力的可选参数
upsample_size: Optional[int] = None, # 可选的上采样大小
attention_mask: Optional[torch.Tensor] = None, # 可选的注意力掩码
encoder_attention_mask: Optional[torch.Tensor] = None, # 可选的编码器注意力掩码
num_frames: int = 1, # 帧数,默认为 1
# 定义一个继承自 nn.Module 的类,用于上采样块
class UpBlockMotion(nn.Module):
# 初始化方法,设置各层的参数
def __init__(
self,
in_channels: int, # 输入通道数
prev_output_channel: int, # 前一层输出的通道数
out_channels: int, # 输出通道数
temb_channels: int, # 时间嵌入通道数
resolution_idx: Optional[int] = None, # 分辨率索引,默认为 None
dropout: float = 0.0, # dropout 概率
num_layers: int = 1, # 层数
resnet_eps: float = 1e-6, # ResNet 的 epsilon 值
resnet_time_scale_shift: str = "default", # ResNet 时间缩放偏移
resnet_act_fn: str = "swish", # ResNet 激活函数
resnet_groups: int = 32, # ResNet 组数
resnet_pre_norm: bool = True, # 是否在前面进行归一化
output_scale_factor: float = 1.0, # 输出缩放因子
add_upsample: bool = True, # 是否添加上采样
temporal_cross_attention_dim: Optional[int] = None, # 时间交叉注意力维度,默认为 None
temporal_num_attention_heads: int = 8, # 时间注意力头数量
temporal_max_seq_length: int = 32, # 时间序列的最大长度
temporal_transformer_layers_per_block: Union[int, Tuple[int]] = 1, # 时间块的变换器层数
):
# 调用父类的初始化方法
super().__init__()
# 初始化空列表,用于存放 ResNet 模块
resnets = []
# 初始化空列表,用于存放运动模块
motion_modules = []
# 支持每个时间块的变换层数量为变量
if isinstance(temporal_transformer_layers_per_block, int):
# 将单个整数转换为与层数相同的元组
temporal_transformer_layers_per_block = (temporal_transformer_layers_per_block,) * num_layers
elif len(temporal_transformer_layers_per_block) != num_layers:
# 检查传入的层数是否与预期一致
raise ValueError(
f"temporal_transformer_layers_per_block must be an integer or a list of integers of length {num_layers}"
)
# 遍历每层,构建 ResNet 和运动模块
for i in range(num_layers):
# 设定跳过连接的通道数
res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
# 设定当前层的输入通道数
resnet_in_channels = prev_output_channel if i == 0 else out_channels
# 添加 ResNetBlock2D 模块到 resnets 列表
resnets.append(
ResnetBlock2D(
# 输入通道数为当前层的输入和跳过连接的通道数之和
in_channels=resnet_in_channels + res_skip_channels,
# 输出通道数设定
out_channels=out_channels,
# 时间嵌入通道数
temb_channels=temb_channels,
# 小常数以避免除零
eps=resnet_eps,
# 组归一化的组数
groups=resnet_groups,
# Dropout 率
dropout=dropout,
# 时间嵌入的归一化方式
time_embedding_norm=resnet_time_scale_shift,
# 激活函数设定
non_linearity=resnet_act_fn,
# 输出尺度因子
output_scale_factor=output_scale_factor,
# 是否使用预归一化
pre_norm=resnet_pre_norm,
)
)
# 添加 AnimateDiffTransformer3D 模块到 motion_modules 列表
motion_modules.append(
AnimateDiffTransformer3D(
# 注意力头的数量
num_attention_heads=temporal_num_attention_heads,
# 输入通道数
in_channels=out_channels,
# 当前层的变换层数量
num_layers=temporal_transformer_layers_per_block[i],
# 组归一化的组数
norm_num_groups=resnet_groups,
# 跨注意力维度
cross_attention_dim=temporal_cross_attention_dim,
# 是否使用注意力偏置
attention_bias=False,
# 激活函数类型
activation_fn="geglu",
# 位置信息嵌入类型
positional_embeddings="sinusoidal",
# 位置信息嵌入数量
num_positional_embeddings=temporal_max_seq_length,
# 每个注意力头的维度
attention_head_dim=out_channels // temporal_num_attention_heads,
)
)
# 将 ResNet 模块列表转换为 nn.ModuleList
self.resnets = nn.ModuleList(resnets)
# 将运动模块列表转换为 nn.ModuleList
self.motion_modules = nn.ModuleList(motion_modules)
# 如果需要上采样,则初始化上采样模块
if add_upsample:
self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
else:
# 否则,设定为 None
self.upsamplers = None
# 设定梯度检查点标志为 False
self.gradient_checkpointing = False
# 保存分辨率索引
self.resolution_idx = resolution_idx
def forward(
# 前向传播方法的参数定义
self,
hidden_states: torch.Tensor,
res_hidden_states_tuple: Tuple[torch.Tensor, ...],
# 可选的时间嵌入
temb: Optional[torch.Tensor] = None,
# 上采样大小
upsample_size=None,
# 帧数,默认为 1
num_frames: int = 1,
# 额外的参数
*args,
**kwargs,
# 函数返回类型为 torch.Tensor
) -> torch.Tensor:
# 检查传入参数是否存在或 "scale" 参数是否为非 None
if len(args) > 0 or kwargs.get("scale", None) is not None:
# 定义弃用提示信息
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
# 调用 deprecate 函数记录弃用警告
deprecate("scale", "1.0.0", deprecation_message)
# 检查 FreeU 是否启用,确保相关属性均不为 None
is_freeu_enabled = (
getattr(self, "s1", None)
and getattr(self, "s2", None)
and getattr(self, "b1", None)
and getattr(self, "b2", None)
)
# 将自定义模块打包成元组,方便遍历
blocks = zip(self.resnets, self.motion_modules)
# 遍历每一对 resnet 和 motion_module
for resnet, motion_module in blocks:
# 从隐藏状态元组中弹出最后一个隐藏状态
res_hidden_states = res_hidden_states_tuple[-1]
# 更新隐藏状态元组,移除最后一个元素
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
# 如果启用 FreeU,则仅对前两个阶段进行操作
if is_freeu_enabled:
# 应用 FreeU 函数获取新的隐藏状态
hidden_states, res_hidden_states = apply_freeu(
self.resolution_idx,
hidden_states,
res_hidden_states,
s1=self.s1,
s2=self.s2,
b1=self.b1,
b2=self.b2,
)
# 将当前隐藏状态和残差隐藏状态在维度 1 上拼接
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
# 如果在训练模式并且启用了梯度检查点
if self.training and self.gradient_checkpointing:
# 定义创建自定义前向传播函数
def create_custom_forward(module):
# 定义自定义前向传播的实现
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
# 如果 torch 版本大于等于 1.11.0
if is_torch_version(">=", "1.11.0"):
# 使用检查点机制保存内存
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet),
hidden_states,
temb,
use_reentrant=False,
)
else:
# 否则使用旧版检查点机制
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet), hidden_states, temb
)
else:
# 否则直接通过 resnet 计算隐藏状态
hidden_states = resnet(hidden_states, temb)
# 通过 motion_module 处理隐藏状态,传入帧数
hidden_states = motion_module(hidden_states, num_frames=num_frames)
# 如果存在上采样器,则对每个上采样器进行处理
if self.upsamplers is not None:
for upsampler in self.upsamplers:
# 通过上采样器处理隐藏状态,传入上采样大小
hidden_states = upsampler(hidden_states, upsample_size)
# 返回最终处理后的隐藏状态
return hidden_states
# 定义 UNetMidBlockCrossAttnMotion 类,继承自 nn.Module
class UNetMidBlockCrossAttnMotion(nn.Module):
# 初始化方法,定义类的参数
def __init__(
self,
in_channels: int, # 输入通道数
temb_channels: int, # 时间嵌入通道数
dropout: float = 0.0, # Dropout 率
num_layers: int = 1, # 层数
transformer_layers_per_block: Union[int, Tuple[int]] = 1, # 每个块的变换层数
resnet_eps: float = 1e-6, # ResNet 的 epsilon 值
resnet_time_scale_shift: str = "default", # ResNet 时间尺度偏移
resnet_act_fn: str = "swish", # ResNet 激活函数类型
resnet_groups: int = 32, # ResNet 组数
resnet_pre_norm: bool = True, # 是否进行前置归一化
num_attention_heads: int = 1, # 注意力头数量
output_scale_factor: float = 1.0, # 输出缩放因子
cross_attention_dim: int = 1280, # 交叉注意力维度
dual_cross_attention: bool = False, # 是否使用双重交叉注意力
use_linear_projection: bool = False, # 是否使用线性投影
upcast_attention: bool = False, # 是否上升注意力精度
attention_type: str = "default", # 注意力类型
temporal_num_attention_heads: int = 1, # 时间注意力头数量
temporal_cross_attention_dim: Optional[int] = None, # 时间交叉注意力维度
temporal_max_seq_length: int = 32, # 时间序列最大长度
temporal_transformer_layers_per_block: Union[int, Tuple[int]] = 1, # 时间块的变换层数
# 前向传播方法,定义输入和输出
def forward(
self,
hidden_states: torch.Tensor, # 隐藏状态的输入张量
temb: Optional[torch.Tensor] = None, # 可选的时间嵌入张量
encoder_hidden_states: Optional[torch.Tensor] = None, # 可选的编码器隐藏状态
attention_mask: Optional[torch.Tensor] = None, # 可选的注意力掩码
cross_attention_kwargs: Optional[Dict[str, Any]] = None, # 可选的交叉注意力参数
encoder_attention_mask: Optional[torch.Tensor] = None, # 可选的编码器注意力掩码
num_frames: int = 1, # 帧数
# 该函数的返回类型为 torch.Tensor
) -> torch.Tensor:
# 检查交叉注意力参数是否不为 None
if cross_attention_kwargs is not None:
# 如果参数中包含 "scale",发出警告,说明该参数已弃用
if cross_attention_kwargs.get("scale", None) is not None:
logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
# 通过第一个残差网络处理隐藏状态
hidden_states = self.resnets[0](hidden_states, temb)
# 将注意力层、残差网络和运动模块打包在一起
blocks = zip(self.attentions, self.resnets[1:], self.motion_modules)
# 遍历每个注意力层、残差网络和运动模块
for attn, resnet, motion_module in blocks:
# 如果在训练模式下并且启用了梯度检查点
if self.training and self.gradient_checkpointing:
# 创建自定义前向函数
def create_custom_forward(module, return_dict=None):
# 定义自定义前向函数,接受任意输入
def custom_forward(*inputs):
# 如果返回字典不为 None,使用返回字典调用模块
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
# 否则直接调用模块
return module(*inputs)
return custom_forward
# 根据 PyTorch 版本设置检查点参数
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
# 调用注意力模块并获取输出的第一个元素
hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
cross_attention_kwargs=cross_attention_kwargs,
attention_mask=attention_mask,
encoder_attention_mask=encoder_attention_mask,
return_dict=False,
)[0]
# 使用梯度检查点对运动模块进行前向传播
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(motion_module),
hidden_states,
temb,
**ckpt_kwargs,
)
# 使用梯度检查点对残差网络进行前向传播
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet),
hidden_states,
temb,
**ckpt_kwargs,
)
else:
# 在非训练模式下直接调用注意力模块
hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
cross_attention_kwargs=cross_attention_kwargs,
attention_mask=attention_mask,
encoder_attention_mask=encoder_attention_mask,
return_dict=False,
)[0]
# 调用运动模块,传入隐藏状态和帧数
hidden_states = motion_module(
hidden_states,
num_frames=num_frames,
)
# 调用残差网络处理隐藏状态
hidden_states = resnet(hidden_states, temb)
# 返回处理后的隐藏状态
return hidden_states
# 定义一个继承自 nn.Module 的运动模块类
class MotionModules(nn.Module):
# 初始化方法,接收多个参数配置运动模块
def __init__(
self,
in_channels: int, # 输入通道数
layers_per_block: int = 2, # 每个模块块的层数,默认是 2
transformer_layers_per_block: Union[int, Tuple[int]] = 8, # 每个块中的变换层数
num_attention_heads: Union[int, Tuple[int]] = 8, # 注意力头的数量
attention_bias: bool = False, # 是否使用注意力偏差
cross_attention_dim: Optional[int] = None, # 交叉注意力维度
activation_fn: str = "geglu", # 激活函数,默认使用 "geglu"
norm_num_groups: int = 32, # 归一化组的数量
max_seq_length: int = 32, # 最大序列长度
):
# 调用父类初始化方法
super().__init__()
# 初始化运动模块列表
self.motion_modules = nn.ModuleList([])
# 如果变换层数是整数,重复为每个模块块填充
if isinstance(transformer_layers_per_block, int):
transformer_layers_per_block = (transformer_layers_per_block,) * layers_per_block
# 检查变换层数与块层数是否匹配
elif len(transformer_layers_per_block) != layers_per_block:
raise ValueError(
f"The number of transformer layers per block must match the number of layers per block, "
f"got {layers_per_block} and {len(transformer_layers_per_block)}"
)
# 遍历每个模块块
for i in range(layers_per_block):
# 向运动模块列表添加 AnimateDiffTransformer3D 实例
self.motion_modules.append(
AnimateDiffTransformer3D(
in_channels=in_channels, # 输入通道数
num_layers=transformer_layers_per_block[i], # 当前块的变换层数
norm_num_groups=norm_num_groups, # 归一化组的数量
cross_attention_dim=cross_attention_dim, # 交叉注意力维度
activation_fn=activation_fn, # 激活函数
attention_bias=attention_bias, # 注意力偏差
num_attention_heads=num_attention_heads, # 注意力头数量
attention_head_dim=in_channels // num_attention_heads, # 每个注意力头的维度
positional_embeddings="sinusoidal", # 使用正弦波的位置嵌入
num_positional_embeddings=max_seq_length, # 位置嵌入的数量
)
)
# 定义一个运动适配器类,结合多个混合类
class MotionAdapter(ModelMixin, ConfigMixin, FromOriginalModelMixin):
@register_to_config
# 初始化方法,配置多个运动适配器参数
def __init__(
self,
block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280), # 块输出通道
motion_layers_per_block: Union[int, Tuple[int]] = 2, # 每个运动块的层数
motion_transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple[int]]] = 1, # 每个运动块中的变换层数
motion_mid_block_layers_per_block: int = 1, # 中间块的层数
motion_transformer_layers_per_mid_block: Union[int, Tuple[int]] = 1, # 中间块中的变换层数
motion_num_attention_heads: Union[int, Tuple[int]] = 8, # 中间块的注意力头数量
motion_norm_num_groups: int = 32, # 中间块的归一化组数量
motion_max_seq_length: int = 32, # 中间块的最大序列长度
use_motion_mid_block: bool = True, # 是否使用中间块
conv_in_channels: Optional[int] = None, # 输入通道数
):
pass # 前向传播方法,尚未实现
# 定义一个修改后的条件 2D UNet 模型
class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, PeftAdapterMixin):
r"""
一个修改后的条件 2D UNet 模型,接收嘈杂样本、条件状态和时间步,返回形状输出。
该模型继承自 [`ModelMixin`]。查看超类文档以获取所有模型的通用方法实现(如下载或保存)。
"""
# 支持梯度检查点
_supports_gradient_checkpointing = True
@register_to_config
# 初始化方法,用于创建类的实例
def __init__(
# 可选参数,样本大小,默认为 None
self,
sample_size: Optional[int] = None,
# 输入通道数,默认为 4
in_channels: int = 4,
# 输出通道数,默认为 4
out_channels: int = 4,
# 下采样块的类型元组
down_block_types: Tuple[str, ...] = (
"CrossAttnDownBlockMotion", # 第一个下采样块类型
"CrossAttnDownBlockMotion", # 第二个下采样块类型
"CrossAttnDownBlockMotion", # 第三个下采样块类型
"DownBlockMotion", # 第四个下采样块类型
),
# 上采样块的类型元组
up_block_types: Tuple[str, ...] = (
"UpBlockMotion", # 第一个上采样块类型
"CrossAttnUpBlockMotion", # 第二个上采样块类型
"CrossAttnUpBlockMotion", # 第三个上采样块类型
"CrossAttnUpBlockMotion", # 第四个上采样块类型
),
# 块的输出通道数元组
block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
# 每个块的层数,默认为 2
layers_per_block: Union[int, Tuple[int]] = 2,
# 下采样填充,默认为 1
downsample_padding: int = 1,
# 中间块的缩放因子,默认为 1
mid_block_scale_factor: float = 1,
# 激活函数类型,默认为 "silu"
act_fn: str = "silu",
# 归一化的组数,默认为 32
norm_num_groups: int = 32,
# 归一化的 epsilon 值,默认为 1e-5
norm_eps: float = 1e-5,
# 交叉注意力的维度,默认为 1280
cross_attention_dim: int = 1280,
# 每个块的变换器层数,默认为 1
transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1,
# 可选参数,反向变换器层数,默认为 None
reverse_transformer_layers_per_block: Optional[Union[int, Tuple[int], Tuple[Tuple]]] = None,
# 时间变换器的层数,默认为 1
temporal_transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1,
# 可选参数,反向时间变换器层数,默认为 None
reverse_temporal_transformer_layers_per_block: Optional[Union[int, Tuple[int], Tuple[Tuple]]] = None,
# 每个中间块的变换器层数,默认为 None
transformer_layers_per_mid_block: Optional[Union[int, Tuple[int]]] = None,
# 每个中间块的时间变换器层数,默认为 1
temporal_transformer_layers_per_mid_block: Optional[Union[int, Tuple[int]]] = 1,
# 是否使用线性投影,默认为 False
use_linear_projection: bool = False,
# 注意力头的数量,默认为 8
num_attention_heads: Union[int, Tuple[int, ...]] = 8,
# 动作最大序列长度,默认为 32
motion_max_seq_length: int = 32,
# 动作注意力头的数量,默认为 8
motion_num_attention_heads: Union[int, Tuple[int, ...]] = 8,
# 可选参数,反向动作注意力头的数量,默认为 None
reverse_motion_num_attention_heads: Optional[Union[int, Tuple[int, ...], Tuple[Tuple[int, ...], ...]]] = None,
# 是否使用动作中间块,默认为 True
use_motion_mid_block: bool = True,
# 中间块的层数,默认为 1
mid_block_layers: int = 1,
# 编码器隐藏层维度,默认为 None
encoder_hid_dim: Optional[int] = None,
# 编码器隐藏层类型,默认为 None
encoder_hid_dim_type: Optional[str] = None,
# 可选参数,附加嵌入类型,默认为 None
addition_embed_type: Optional[str] = None,
# 可选参数,附加时间嵌入维度,默认为 None
addition_time_embed_dim: Optional[int] = None,
# 可选参数,投影类别嵌入的输入维度,默认为 None
projection_class_embeddings_input_dim: Optional[int] = None,
# 可选参数,时间条件投影维度,默认为 None
time_cond_proj_dim: Optional[int] = None,
# 类方法,用于从 UNet2DConditionModel 创建对象
@classmethod
def from_unet2d(
cls,
# UNet2DConditionModel 对象
unet: UNet2DConditionModel,
# 可选的运动适配器,默认为 None
motion_adapter: Optional[MotionAdapter] = None,
# 是否加载权重,默认为 True
load_weights: bool = True,
# 冻结 UNet2DConditionModel 的权重,只保留运动模块可训练,便于微调
def freeze_unet2d_params(self) -> None:
"""Freeze the weights of just the UNet2DConditionModel, and leave the motion modules
unfrozen for fine tuning.
"""
# 冻结所有参数
for param in self.parameters():
# 将参数的 requires_grad 属性设置为 False,禁止梯度更新
param.requires_grad = False
# 解冻运动模块
for down_block in self.down_blocks:
# 获取当前下采样块的运动模块
motion_modules = down_block.motion_modules
for param in motion_modules.parameters():
# 将运动模块参数的 requires_grad 属性设置为 True,允许梯度更新
param.requires_grad = True
for up_block in self.up_blocks:
# 获取当前上采样块的运动模块
motion_modules = up_block.motion_modules
for param in motion_modules.parameters():
# 将运动模块参数的 requires_grad 属性设置为 True,允许梯度更新
param.requires_grad = True
# 检查中间块是否具有运动模块
if hasattr(self.mid_block, "motion_modules"):
# 获取中间块的运动模块
motion_modules = self.mid_block.motion_modules
for param in motion_modules.parameters():
# 将运动模块参数的 requires_grad 属性设置为 True,允许梯度更新
param.requires_grad = True
# 加载运动模块的状态字典
def load_motion_modules(self, motion_adapter: Optional[MotionAdapter]) -> None:
# 遍历运动适配器的下采样块
for i, down_block in enumerate(motion_adapter.down_blocks):
# 加载下采样块的运动模块状态字典
self.down_blocks[i].motion_modules.load_state_dict(down_block.motion_modules.state_dict())
# 遍历运动适配器的上采样块
for i, up_block in enumerate(motion_adapter.up_blocks):
# 加载上采样块的运动模块状态字典
self.up_blocks[i].motion_modules.load_state_dict(up_block.motion_modules.state_dict())
# 支持没有中间块的旧运动模块
if hasattr(self.mid_block, "motion_modules"):
# 加载中间块的运动模块状态字典
self.mid_block.motion_modules.load_state_dict(motion_adapter.mid_block.motion_modules.state_dict())
# 保存运动模块的状态
def save_motion_modules(
self,
save_directory: str,
is_main_process: bool = True,
safe_serialization: bool = True,
variant: Optional[str] = None,
push_to_hub: bool = False,
**kwargs,
) -> None:
# 获取当前模型的状态字典
state_dict = self.state_dict()
# 提取所有运动模块的状态
motion_state_dict = {}
for k, v in state_dict.items():
# 筛选出包含 "motion_modules" 的键值对
if "motion_modules" in k:
motion_state_dict[k] = v
# 创建运动适配器实例
adapter = MotionAdapter(
block_out_channels=self.config["block_out_channels"],
motion_layers_per_block=self.config["layers_per_block"],
motion_norm_num_groups=self.config["norm_num_groups"],
motion_num_attention_heads=self.config["motion_num_attention_heads"],
motion_max_seq_length=self.config["motion_max_seq_length"],
use_motion_mid_block=self.config["use_motion_mid_block"],
)
# 加载运动状态字典
adapter.load_state_dict(motion_state_dict)
# 保存适配器的预训练状态
adapter.save_pretrained(
save_directory=save_directory,
is_main_process=is_main_process,
safe_serialization=safe_serialization,
variant=variant,
push_to_hub=push_to_hub,
**kwargs,
)
@property
# 从 diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors 复制的属性
# 定义一个方法,返回模型中所有注意力处理器的字典
def attn_processors(self) -> Dict[str, AttentionProcessor]:
r"""
返回值:
`dict` 类型的注意力处理器: 包含模型中所有注意力处理器的字典,
按照其权重名称索引。
"""
# 初始化一个空字典,用于存储注意力处理器
processors = {}
# 定义一个递归函数,用于添加注意力处理器
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
# 检查模块是否具有 'get_processor' 方法
if hasattr(module, "get_processor"):
# 将处理器添加到字典中,键为处理器名称
processors[f"{name}.processor"] = module.get_processor()
# 遍历模块的所有子模块
for sub_name, child in module.named_children():
# 递归调用,继续添加子模块的处理器
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
# 返回处理器字典
return processors
# 遍历当前对象的所有子模块
for name, module in self.named_children():
# 调用递归函数添加所有处理器
fn_recursive_add_processors(name, module, processors)
# 返回最终的处理器字典
return processors
# 从 diffusers.models.unets.unet_2d_condition 中复制的方法,用于设置注意力处理器
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
r"""
设置用于计算注意力的注意力处理器。
参数:
processor (`dict` of `AttentionProcessor` 或仅 `AttentionProcessor`):
实例化的处理器类或处理器类的字典,将被设置为**所有** `Attention` 层的处理器。
如果 `processor` 是字典,键需要定义相应的交叉注意力处理器的路径。
当设置可训练的注意力处理器时,强烈推荐这样做。
"""
# 获取当前注意力处理器字典的键数量
count = len(self.attn_processors.keys())
# 检查传入的处理器字典长度是否与注意力层数量匹配
if isinstance(processor, dict) and len(processor) != count:
# 如果不匹配,抛出错误
raise ValueError(
f"传入了处理器字典,但处理器数量 {len(processor)} 与"
f" 注意力层数量 {count} 不匹配。请确保传入 {count} 个处理器类。"
)
# 定义一个递归函数,用于设置注意力处理器
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
# 检查模块是否具有 'set_processor' 方法
if hasattr(module, "set_processor"):
# 如果处理器不是字典,直接设置处理器
if not isinstance(processor, dict):
module.set_processor(processor)
else:
# 从字典中弹出对应的处理器并设置
module.set_processor(processor.pop(f"{name}.processor"))
# 遍历模块的所有子模块
for sub_name, child in module.named_children():
# 递归调用,继续设置子模块的处理器
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
# 遍历当前对象的所有子模块
for name, module in self.named_children():
# 调用递归函数设置所有处理器
fn_recursive_attn_processor(name, module, processor)
# 定义一个方法以启用前向分块处理
def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None:
"""
设置注意力处理器以使用[前馈分块](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers)。
参数:
chunk_size (`int`, *可选*):
前馈层的块大小。如果未指定,将单独对维度为`dim`的每个张量运行前馈层。
dim (`int`, *可选*, 默认为`0`):
前馈计算应分块的维度。选择dim=0(批次)或dim=1(序列长度)。
"""
# 检查dim参数是否在有效范围内(0或1)
if dim not in [0, 1]:
raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}")
# 默认块大小为1
chunk_size = chunk_size or 1
# 定义递归前馈函数以设置模块的分块前馈处理
def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
# 如果模块有set_chunk_feed_forward属性,设置块大小和维度
if hasattr(module, "set_chunk_feed_forward"):
module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)
# 遍历模块的子模块
for child in module.children():
fn_recursive_feed_forward(child, chunk_size, dim)
# 遍历当前对象的子模块,应用递归前馈函数
for module in self.children():
fn_recursive_feed_forward(module, chunk_size, dim)
# 定义一个方法以禁用前向分块处理
def disable_forward_chunking(self) -> None:
# 定义递归前馈函数以设置模块的分块前馈处理为None
def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
# 如果模块有set_chunk_feed_forward属性,设置块大小和维度为None
if hasattr(module, "set_chunk_feed_forward"):
module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)
# 遍历模块的子模块
for child in module.children():
fn_recursive_feed_forward(child, chunk_size, dim)
# 遍历当前对象的子模块,应用递归前馈函数
for module in self.children():
fn_recursive_feed_forward(module, None, 0)
# 从diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor复制的方法
def set_default_attn_processor(self) -> None:
"""
禁用自定义注意力处理器并设置默认的注意力实现。
"""
# 如果所有注意力处理器都是ADDED_KV_ATTENTION_PROCESSORS类型
if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
# 设置处理器为AttnAddedKVProcessor
processor = AttnAddedKVProcessor()
# 如果所有注意力处理器都是CROSS_ATTENTION_PROCESSORS类型
elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
# 设置处理器为AttnProcessor
processor = AttnProcessor()
else:
# 抛出错误,表示不能在不匹配的注意力处理器类型下调用该方法
raise ValueError(
f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
)
# 设置当前对象的注意力处理器
self.set_attn_processor(processor)
# 定义一个方法以设置模块的梯度检查点
def _set_gradient_checkpointing(self, module, value: bool = False) -> None:
# 检查模块是否为特定类型,如果是则设置其梯度检查点属性
if isinstance(module, (CrossAttnDownBlockMotion, DownBlockMotion, CrossAttnUpBlockMotion, UpBlockMotion)):
module.gradient_checkpointing = value
# 从diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.enable_freeu复制的方法
# 启用 FreeU 机制,接受四个浮点型缩放因子作为参数
def enable_freeu(self, s1: float, s2: float, b1: float, b2: float) -> None:
# 文档字符串,描述该方法的作用及参数含义
r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497.
The suffixes after the scaling factors represent the stage blocks where they are being applied.
Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of values that
are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL.
Args:
s1 (`float`):
Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to
mitigate the "oversmoothing effect" in the enhanced denoising process.
s2 (`float`):
Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to
mitigate the "oversmoothing effect" in the enhanced denoising process.
b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features.
b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features.
"""
# 遍历上采样块,并为每个块设置缩放因子
for i, upsample_block in enumerate(self.up_blocks):
# 为上采样块设置阶段1的缩放因子
setattr(upsample_block, "s1", s1)
# 为上采样块设置阶段2的缩放因子
setattr(upsample_block, "s2", s2)
# 为上采样块设置阶段1的主干特征缩放因子
setattr(upsample_block, "b1", b1)
# 为上采样块设置阶段2的主干特征缩放因子
setattr(upsample_block, "b2", b2)
# 禁用 FreeU 机制
def disable_freeu(self) -> None:
# 文档字符串,描述该方法的作用
"""Disables the FreeU mechanism."""
# 定义 FreeU 相关的键名集合
freeu_keys = {"s1", "s2", "b1", "b2"}
# 遍历上采样块
for i, upsample_block in enumerate(self.up_blocks):
# 遍历 FreeU 键名
for k in freeu_keys:
# 检查上采样块是否具有该属性或该属性是否不为 None
if hasattr(upsample_block, k) or getattr(upsample_block, k, None) is not None:
# 将上采样块的该属性设置为 None
setattr(upsample_block, k, None)
# 启用融合的 QKV 投影
def fuse_qkv_projections(self):
# 文档字符串,描述该方法的作用
"""
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
are fused. For cross-attention modules, key and value projection matrices are fused.
<Tip warning={true}>
This API is 🧪 experimental.
</Tip>
"""
# 初始化原始注意力处理器为 None
self.original_attn_processors = None
# 遍历注意力处理器
for _, attn_processor in self.attn_processors.items():
# 检查注意力处理器类名中是否包含 "Added"
if "Added" in str(attn_processor.__class__.__name__):
# 抛出异常,说明不支持该操作
raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
# 保存原始的注意力处理器
self.original_attn_processors = self.attn_processors
# 遍历所有模块
for module in self.modules():
# 检查模块是否为 Attention 类型
if isinstance(module, Attention):
# 融合投影
module.fuse_projections(fuse=True)
# 设置融合后的注意力处理器
self.set_attn_processor(FusedAttnProcessor2_0())
# 解融合 QKV 投影的方法(省略具体实现)
# 定义一个禁用融合 QKV 投影的方法
def unfuse_qkv_projections(self):
"""如果启用了,禁用融合 QKV 投影。
<Tip warning={true}>
此 API 是 🧪 实验性。
</Tip>
"""
# 检查原始注意力处理器是否不为 None
if self.original_attn_processors is not None:
# 设置当前注意力处理器为原始的注意力处理器
self.set_attn_processor(self.original_attn_processors)
# 定义前向传播方法,接收多个参数
def forward(
self,
# 输入样本张量
sample: torch.Tensor,
# 时间步,可以是张量、浮点数或整数
timestep: Union[torch.Tensor, float, int],
# 编码器隐藏状态张量
encoder_hidden_states: torch.Tensor,
# 可选的时间步条件张量
timestep_cond: Optional[torch.Tensor] = None,
# 可选的注意力掩码张量
attention_mask: Optional[torch.Tensor] = None,
# 可选的交叉注意力参数字典
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
# 可选的附加条件参数字典
added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
# 可选的下块附加残差元组
down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
# 可选的中间块附加残差张量
mid_block_additional_residual: Optional[torch.Tensor] = None,
# 是否返回字典格式的结果,默认为 True
return_dict: bool = True,