CogView3 & CogView-3Plus 微调代码源码解析(二)
.\cogview3-finetune\sat\sgm\models\__init__.py
# 从同一模块导入 AutoencodingEngine 类,用于后续的自动编码器操作
from .autoencoder import AutoencodingEngine
# 注释文本(可能是无关信息或标识符)
#XuDwndGaCFo
.\cogview3-finetune\sat\sgm\modules\attention.py
# 导入数学库
import math
# 从 inspect 模块导入 isfunction 函数,用于检查对象是否为函数
from inspect import isfunction
# 导入 Any 和 Optional 类型
from typing import Any, Optional
# 导入 PyTorch 库
import torch
# 导入 PyTorch 的功能性模块
import torch.nn.functional as F
# 从 einops 库导入 rearrange 和 repeat 函数
from einops import rearrange, repeat
# 导入版本管理工具
from packaging import version
# 导入 PyTorch 的神经网络模块
from torch import nn
# 检查 PyTorch 版本是否大于或等于 2.0.0
if version.parse(torch.__version__) >= version.parse("2.0.0"):
# 设置 SDP_IS_AVAILABLE 为 True,表示 SDP 后端可用
SDP_IS_AVAILABLE = True
# 从 PyTorch 导入 SDPBackend 和 sdp_kernel
from torch.backends.cuda import SDPBackend, sdp_kernel
# 定义后端映射字典,根据不同的后端配置相应的选项
BACKEND_MAP = {
SDPBackend.MATH: {
"enable_math": True,
"enable_flash": False,
"enable_mem_efficient": False,
},
SDPBackend.FLASH_ATTENTION: {
"enable_math": False,
"enable_flash": True,
"enable_mem_efficient": False,
},
SDPBackend.EFFICIENT_ATTENTION: {
"enable_math": False,
"enable_flash": False,
"enable_mem_efficient": True,
},
None: {"enable_math": True, "enable_flash": True, "enable_mem_efficient": True},
}
else:
# 从上下文管理库导入 nullcontext
from contextlib import nullcontext
# 设置 SDP_IS_AVAILABLE 为 False,表示 SDP 后端不可用
SDP_IS_AVAILABLE = False
# 将 sdp_kernel 设置为 nullcontext
sdp_kernel = nullcontext
# 打印提示信息,告知用户当前 PyTorch 版本不支持 SDP 后端
print(
f"No SDP backend available, likely because you are running in pytorch versions < 2.0. In fact, "
f"you are using PyTorch {torch.__version__}. You might want to consider upgrading."
)
# 尝试导入 xformers 和 xformers.ops
try:
import xformers
import xformers.ops
# 如果导入成功,设置 XFORMERS_IS_AVAILABLE 为 True
XFORMERS_IS_AVAILABLE = True
# 如果导入失败,设置 XFORMERS_IS_AVAILABLE 为 False,并打印提示信息
except:
XFORMERS_IS_AVAILABLE = False
print("no module 'xformers'. Processing without...")
# 从 diffusionmodules.util 模块导入 checkpoint 函数
from .diffusionmodules.util import checkpoint
# 定义 exists 函数,检查输入值是否存在
def exists(val):
return val is not None
# 定义 uniq 函数,返回数组中唯一元素的键
def uniq(arr):
return {el: True for el in arr}.keys()
# 定义 default 函数,如果 val 存在则返回它,否则返回默认值
def default(val, d):
if exists(val):
return val
return d() if isfunction(d) else d
# 定义 max_neg_value 函数,返回给定张量类型的最大负值
def max_neg_value(t):
return -torch.finfo(t.dtype).max
# 定义 init_ 函数,初始化张量
def init_(tensor):
# 获取张量的最后一维的大小
dim = tensor.shape[-1]
# 计算标准差
std = 1 / math.sqrt(dim)
# 在区间 [-std, std] 内均匀初始化张量
tensor.uniform_(-std, std)
return tensor
# 定义 GEGLU 类,继承自 nn.Module
class GEGLU(nn.Module):
# 初始化方法,设置输入和输出维度
def __init__(self, dim_in, dim_out):
super().__init__()
# 创建一个线性投影层,将输入维度映射到两倍的输出维度
self.proj = nn.Linear(dim_in, dim_out * 2)
# 前向传播方法
def forward(self, x):
# 将输入通过投影层,分割为 x 和 gate
x, gate = self.proj(x).chunk(2, dim=-1)
# 返回 x 与 gate 的 GELU 激活后的乘积
return x * F.gelu(gate)
# 定义 FeedForward 类,继承自 nn.Module
class FeedForward(nn.Module):
# 初始化方法,设置维度、输出维度、乘法因子、是否使用 GEGLU 和 dropout 率
def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
super().__init__()
# 计算内部维度
inner_dim = int(dim * mult)
# 如果 dim_out 未定义,使用 dim 作为默认值
dim_out = default(dim_out, dim)
# 根据是否使用 GEGLU 创建输入投影层
project_in = (
nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU())
if not glu
else GEGLU(dim, inner_dim)
)
# 定义网络结构
self.net = nn.Sequential(
project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)
)
# 前向传播方法
def forward(self, x):
# 通过网络结构处理输入
return self.net(x)
# 定义 zero_module 函数,清零模块的参数并返回该模块
def zero_module(module):
"""
Zero out the parameters of a module and return it.
"""
# 遍历模块的所有参数
for p in module.parameters():
# 将参数的梯度断开并清零
p.detach().zero_()
return module
# 定义 Normalize 函数,接收输入通道数
def Normalize(in_channels):
# 返回一个 GroupNorm 实例,用于对输入进行分组归一化
return torch.nn.GroupNorm(
# 设置分组数量为 32
num_groups=32,
# 设置输入通道数量
num_channels=in_channels,
# 设置一个小的 epsilon 值以避免除零错误
eps=1e-6,
# 设定 affine 为 True,以便进行可学习的仿射变换
affine=True
)
# 定义线性注意力机制的类,继承自 nn.Module
class LinearAttention(nn.Module):
# 初始化方法,设置参数维度和头数
def __init__(self, dim, heads=4, dim_head=32):
# 调用父类构造函数
super().__init__()
# 设置头数
self.heads = heads
# 计算隐藏维度
hidden_dim = dim_head * heads
# 定义输入到 QKV 的卷积层,输出通道为 hidden_dim 的三倍
self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
# 定义输出卷积层,输出通道为原始维度
self.to_out = nn.Conv2d(hidden_dim, dim, 1)
# 前向传播方法
def forward(self, x):
# 获取输入的批次大小、通道数、高度和宽度
b, c, h, w = x.shape
# 将输入通过 QKV 卷积层
qkv = self.to_qkv(x)
# 将 QKV 的输出重排以分开 Q、K、V
q, k, v = rearrange(
qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3
)
# 对 K 进行 softmax 归一化
k = k.softmax(dim=-1)
# 计算上下文向量,通过爱因斯坦求和约定
context = torch.einsum("bhdn,bhen->bhde", k, v)
# 使用上下文向量和 Q 计算输出
out = torch.einsum("bhde,bhdn->bhen", context, q)
# 重排输出以恢复到原始形状
out = rearrange(
out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w
)
# 返回最终输出
return self.to_out(out)
# 定义空间自注意力机制的类,继承自 nn.Module
class SpatialSelfAttention(nn.Module):
# 初始化方法,设置输入通道
def __init__(self, in_channels):
# 调用父类构造函数
super().__init__()
# 存储输入通道数
self.in_channels = in_channels
# 创建归一化层
self.norm = Normalize(in_channels)
# 创建 Q、K、V 的卷积层,输出通道与输入通道相同
self.q = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=1, stride=1, padding=0
)
self.k = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=1, stride=1, padding=0
)
self.v = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=1, stride=1, padding=0
)
# 创建输出的卷积层
self.proj_out = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=1, stride=1, padding=0
)
# 前向传播方法
def forward(self, x):
# 初始化 h_ 为输入 x
h_ = x
# 对 h_ 进行归一化处理
h_ = self.norm(h_)
# 计算 Q、K、V
q = self.q(h_)
k = self.k(h_)
v = self.v(h_)
# 计算注意力权重
b, c, h, w = q.shape
# 将 Q 重排为 (batch_size, h*w, c) 的形状
q = rearrange(q, "b c h w -> b (h w) c")
# 将 K 重排为 (batch_size, c, h*w) 的形状
k = rearrange(k, "b c h w -> b c (h w)")
# 计算 Q 和 K 的点积以获得权重
w_ = torch.einsum("bij,bjk->bik", q, k)
# 对权重进行缩放
w_ = w_ * (int(c) ** (-0.5))
# 对权重进行 softmax 归一化
w_ = torch.nn.functional.softmax(w_, dim=2)
# 处理 V
v = rearrange(v, "b c h w -> b c (h w)")
# 重排权重 w_ 为 (batch_size, h*w, h*w) 的形状
w_ = rearrange(w_, "b i j -> b j i")
# 计算 h_,通过 V 和权重相乘
h_ = torch.einsum("bij,bjk->bik", v, w_)
# 将 h_ 重排回原始形状
h_ = rearrange(h_, "b c (h w) -> b c h w", h=h)
# 通过 proj_out 进行最终的线性变换
h_ = self.proj_out(h_)
# 返回输入与处理后的 h_ 的和
return x + h_
# 定义交叉注意力机制的类,继承自 nn.Module
class CrossAttention(nn.Module):
# 初始化方法,设置查询、上下文维度和其他参数
def __init__(
self,
query_dim,
context_dim=None,
heads=8,
dim_head=64,
dropout=0.0,
backend=None,
):
# 调用父类构造函数
super().__init__()
# 计算内部维度
inner_dim = dim_head * heads
# 如果没有提供上下文维度,默认使用查询维度
context_dim = default(context_dim, query_dim)
# 设置缩放因子
self.scale = dim_head**-0.5
# 设置头数
self.heads = heads
# 定义 Q、K、V 的线性变换
self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
# 定义输出层,包括线性变换和 dropout
self.to_out = nn.Sequential(
nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)
)
# 存储后端
self.backend = backend
# 定义前向传播函数,接收输入数据及相关参数
def forward(
self,
x,
context=None,
mask=None,
additional_tokens=None,
n_times_crossframe_attn_in_self=0,
):
# 获取注意力头的数量
h = self.heads
# 如果有额外的 tokens
if additional_tokens is not None:
# 获取输出序列开始时的掩码 token 数量
n_tokens_to_mask = additional_tokens.shape[1]
# 将额外的 token 添加到输入数据前
x = torch.cat([additional_tokens, x], dim=1)
# 通过线性变换生成查询向量
q = self.to_q(x)
# 使用默认值或输入作为上下文
context = default(context, x)
# 通过线性变换生成键向量
k = self.to_k(context)
# 通过线性变换生成值向量
v = self.to_v(context)
# 如果需要进行跨帧注意力
if n_times_crossframe_attn_in_self:
# 验证输入批次大小可以被跨帧次数整除
assert x.shape[0] % n_times_crossframe_attn_in_self == 0
# 计算每次跨帧的批次大小
n_cp = x.shape[0] // n_times_crossframe_attn_in_self
# 重复键向量以适应跨帧注意力
k = repeat(
k[::n_times_crossframe_attn_in_self], "b ... -> (b n) ...", n=n_cp
)
# 重复值向量以适应跨帧注意力
v = repeat(
v[::n_times_crossframe_attn_in_self], "b ... -> (b n) ...", n=n_cp
)
# 将查询、键、值向量重排为适合多头注意力的形状
q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v))
## old
"""
# 计算查询与键之间的相似度,并进行缩放
sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
# 删除查询和键以节省内存
del q, k
# 如果存在掩码
if exists(mask):
# 将掩码重排为适合的形状
mask = rearrange(mask, 'b ... -> b (...)')
# 获取相似度的最大负值
max_neg_value = -torch.finfo(sim.dtype).max
# 重复掩码以适应多头
mask = repeat(mask, 'b j -> (b h) () j', h=h)
# 使用掩码填充相似度矩阵
sim.masked_fill_(~mask, max_neg_value)
# 应用 softmax 计算注意力权重
sim = sim.softmax(dim=-1)
# 使用注意力权重加权值向量,生成输出
out = einsum('b i j, b j d -> b i d', sim, v)
"""
## new
# 使用指定的后端进行缩放点积注意力计算
with sdp_kernel(**BACKEND_MAP[self.backend]):
# print("dispatching into backend", self.backend, "q/k/v shape: ", q.shape, k.shape, v.shape)
out = F.scaled_dot_product_attention(
q, k, v, attn_mask=mask
) # 默认缩放因子为 dim_head ** -0.5
# 删除查询、键和值向量以释放内存
del q, k, v
# 将输出重排为适合最终输出的形状
out = rearrange(out, "b h n d -> b n (h d)", h=h)
# 如果有额外的 tokens
if additional_tokens is not None:
# 移除额外的 token
out = out[:, n_tokens_to_mask:]
# 返回最终的输出结果
return self.to_out(out)
# 定义一个内存高效的交叉注意力模块,继承自 nn.Module
class MemoryEfficientCrossAttention(nn.Module):
# 初始化方法,接受查询维度、上下文维度、头数、每个头的维度和丢弃率等参数
def __init__(
self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0, **kwargs
):
# 调用父类的初始化方法
super().__init__()
# 计算每个头的内部维度,等于头数乘以每个头的维度
inner_dim = dim_head * heads
# 如果上下文维度未提供,则将其设置为查询维度
context_dim = default(context_dim, query_dim)
# 保存头数和每个头的维度到实例变量
self.heads = heads
self.dim_head = dim_head
# 创建线性层,将查询输入转换为内部维度,不使用偏置
self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
# 创建线性层,将上下文输入转换为内部维度,不使用偏置
self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
# 创建线性层,将上下文输入转换为内部维度,不使用偏置
self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
# 创建一个顺序容器,包含将内部维度转换回查询维度的线性层和丢弃层
self.to_out = nn.Sequential(
nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)
)
# 初始化注意力操作为 None
self.attention_op: Optional[Any] = None
# 定义前向传播方法,接受输入、上下文、掩码和其他可选参数
def forward(
self,
x,
context=None,
mask=None,
additional_tokens=None,
n_times_crossframe_attn_in_self=0,
):
# 检查是否提供了额外的令牌
if additional_tokens is not None:
# 获取输出序列开头的被遮掩令牌数量
n_tokens_to_mask = additional_tokens.shape[1]
# 将额外的令牌与当前输入合并
x = torch.cat([additional_tokens, x], dim=1)
# 将输入转换为查询向量
q = self.to_q(x)
# 使用默认值或输入作为上下文
context = default(context, x)
# 将上下文转换为键向量
k = self.to_k(context)
# 将上下文转换为值向量
v = self.to_v(context)
# 检查是否需要进行跨帧注意力的重新编程
if n_times_crossframe_attn_in_self:
# 进行跨帧注意力的重新编程,参考 https://arxiv.org/abs/2303.13439
assert x.shape[0] % n_times_crossframe_attn_in_self == 0
# k的维度处理,使用每n_times_crossframe_attn_in_self帧的一个
k = repeat(
k[::n_times_crossframe_attn_in_self],
"b ... -> (b n) ...",
n=n_times_crossframe_attn_in_self,
)
# v的维度处理,使用每n_times_crossframe_attn_in_self帧的一个
v = repeat(
v[::n_times_crossframe_attn_in_self],
"b ... -> (b n) ...",
n=n_times_crossframe_attn_in_self,
)
# 获取批次大小和特征维度
b, _, _ = q.shape
# 对 q, k, v 进行维度调整和重塑
q, k, v = map(
lambda t: t.unsqueeze(3) # 在最后一维添加一个新维度
.reshape(b, t.shape[1], self.heads, self.dim_head) # 重塑为(batch, seq_len, heads, dim_head)
.permute(0, 2, 1, 3) # 重新排列维度为(batch, heads, seq_len, dim_head)
.reshape(b * self.heads, t.shape[1], self.dim_head) # 再次重塑为(batch * heads, seq_len, dim_head)
.contiguous(), # 确保内存连续性
(q, k, v), # 对 q, k, v 进行相同处理
)
# 实际计算注意力,这个过程是最不可或缺的
out = xformers.ops.memory_efficient_attention(
q, k, v, attn_bias=None, op=self.attention_op
)
# TODO: 将这个直接用作注意力操作中的偏置
if exists(mask):
# 如果存在遮掩,抛出未实现异常
raise NotImplementedError
# 对输出进行维度调整,适配最终输出形状
out = (
out.unsqueeze(0) # 在最前面添加一个新维度
.reshape(b, self.heads, out.shape[1], self.dim_head) # 重塑为(batch, heads, seq_len, dim_head)
.permute(0, 2, 1, 3) # 重新排列维度为(batch, seq_len, heads, dim_head)
.reshape(b, out.shape[1], self.heads * self.dim_head) # 再次重塑为(batch, seq_len, heads * dim_head)
)
# 如果有额外的令牌,则移除它们
if additional_tokens is not None:
out = out[:, n_tokens_to_mask:] # 切除被遮掩的令牌部分
# 将输出转换为最终的输出格式
return self.to_out(out)
# 定义基本的变换器模块类,继承自 nn.Module
class BasicTransformerBlock(nn.Module):
# 定义可用的注意力模式,映射到对应的类
ATTENTION_MODES = {
"softmax": CrossAttention, # 普通注意力
"softmax-xformers": MemoryEfficientCrossAttention, # 记忆高效注意力
}
# 初始化方法,设置模型的参数
def __init__(
self,
dim, # 输入的维度
n_heads, # 注意力头的数量
d_head, # 每个注意力头的维度
dropout=0.0, # dropout 比例
context_dim=None, # 上下文的维度
gated_ff=True, # 是否使用门控前馈网络
checkpoint=True, # 是否启用检查点
disable_self_attn=False, # 是否禁用自注意力
attn_mode="softmax", # 注意力模式,默认为 softmax
sdp_backend=None, # 后端配置
):
# 调用父类初始化方法
super().__init__()
# 检查给定的注意力模式是否有效
assert attn_mode in self.ATTENTION_MODES
# 如果使用的注意力模式不支持,回退到默认模式
if attn_mode != "softmax" and not XFORMERS_IS_AVAILABLE:
print(
f"Attention mode '{attn_mode}' is not available. Falling back to native attention. "
f"This is not a problem in Pytorch >= 2.0. FYI, you are running with PyTorch version {torch.__version__}"
)
attn_mode = "softmax" # 回退到 softmax 模式
# 如果使用普通注意力且不支持,给出提示并调整模式
elif attn_mode == "softmax" and not SDP_IS_AVAILABLE:
print(
"We do not support vanilla attention anymore, as it is too expensive. Sorry."
)
# 确保已安装 xformers
if not XFORMERS_IS_AVAILABLE:
assert (
False
), "Please install xformers via e.g. 'pip install xformers==0.0.16'"
else:
print("Falling back to xformers efficient attention.")
attn_mode = "softmax-xformers" # 使用 xformers 模式
# 根据最终的注意力模式选择对应的类
attn_cls = self.ATTENTION_MODES[attn_mode]
# 检查 PyTorch 版本是否支持指定的后端
if version.parse(torch.__version__) >= version.parse("2.0.0"):
assert sdp_backend is None or isinstance(sdp_backend, SDPBackend)
else:
assert sdp_backend is None # 对于旧版本,后端必须为 None
self.disable_self_attn = disable_self_attn # 保存禁用自注意力的设置
# 初始化第一个注意力层
self.attn1 = attn_cls(
query_dim=dim, # 查询的维度
heads=n_heads, # 注意力头数量
dim_head=d_head, # 每个头的维度
dropout=dropout, # dropout 比例
context_dim=context_dim if self.disable_self_attn else None, # 上下文维度
backend=sdp_backend, # 后端配置
) # 如果未禁用自注意力,则为自注意力层
# 初始化前馈网络
self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff) # 前馈网络配置
# 初始化第二个注意力层
self.attn2 = attn_cls(
query_dim=dim, # 查询的维度
context_dim=context_dim, # 上下文维度
heads=n_heads, # 注意力头数量
dim_head=d_head, # 每个头的维度
dropout=dropout, # dropout 比例
backend=sdp_backend, # 后端配置
) # 如果上下文维度为 None,则为自注意力层
# 初始化层归一化层
self.norm1 = nn.LayerNorm(dim) # 第一个归一化层
self.norm2 = nn.LayerNorm(dim) # 第二个归一化层
self.norm3 = nn.LayerNorm(dim) # 第三个归一化层
self.checkpoint = checkpoint # 保存检查点设置
# 如果启用检查点,输出相关信息(代码暂时注释掉)
# if self.checkpoint:
# print(f"{self.__class__.__name__} is using checkpointing")
# 前向传播方法,定义输入和上下文的处理
def forward(
self, x, # 输入数据
context=None, # 上下文数据
additional_tokens=None, # 额外的 token
n_times_crossframe_attn_in_self=0 # 跨帧自注意力的次数
):
# 创建一个字典 kwargs,初始包含键 "x" 和参数 x 的值
kwargs = {"x": x}
# 如果 context 参数不为 None,则将其添加到 kwargs 字典中
if context is not None:
kwargs.update({"context": context})
# 如果 additional_tokens 参数不为 None,则将其添加到 kwargs 字典中
if additional_tokens is not None:
kwargs.update({"additional_tokens": additional_tokens})
# 如果 n_times_crossframe_attn_in_self 为真,则将其添加到 kwargs 字典中
if n_times_crossframe_attn_in_self:
kwargs.update(
{"n_times_crossframe_attn_in_self": n_times_crossframe_attn_in_self}
)
# 返回调用 checkpoint 函数,传入 _forward 方法和相关参数
return checkpoint(
self._forward, (x, context), self.parameters(), self.checkpoint
)
def _forward(
# 定义 _forward 方法,接收 x、context、additional_tokens 和 n_times_crossframe_attn_in_self 参数
self, x, context=None, additional_tokens=None, n_times_crossframe_attn_in_self=0
):
# 对 x 进行规范化,然后通过 attn1 方法进行自注意力计算,并根据条件选择 context 和其他参数
x = (
self.attn1(
self.norm1(x),
context=context if self.disable_self_attn else None,
additional_tokens=additional_tokens,
n_times_crossframe_attn_in_self=n_times_crossframe_attn_in_self
if not self.disable_self_attn
else 0,
)
+ x # 将 self.attn1 的输出与原始 x 相加
)
# 继续对 x 进行规范化,通过 attn2 方法进行自注意力计算
x = (
self.attn2(
self.norm2(x), context=context, additional_tokens=additional_tokens
)
+ x # 将 self.attn2 的输出与当前 x 相加
)
# 对 x 进行规范化,然后通过前馈网络 ff 处理,再与原始 x 相加
x = self.ff(self.norm3(x)) + x
# 返回处理后的 x
return x
# 定义基本的单层变换器块类,继承自 nn.Module
class BasicTransformerSingleLayerBlock(nn.Module):
# 定义不同的注意力模式及其对应的类
ATTENTION_MODES = {
"softmax": CrossAttention, # 标准注意力
"softmax-xformers": MemoryEfficientCrossAttention # 针对 A100s 的优化版本,速度可能略慢
# (todo 可能依赖于 head_dim,需检查,对于 dim!=[16,32,64,128] 时退回到半优化内核)
}
# 初始化方法,设置基本参数
def __init__(
self,
dim, # 特征维度
n_heads, # 注意力头数
d_head, # 每个注意力头的维度
dropout=0.0, # 丢弃率
context_dim=None, # 上下文维度(可选)
gated_ff=True, # 是否使用门控前馈网络
checkpoint=True, # 是否启用检查点
attn_mode="softmax", # 注意力模式
):
# 调用父类构造函数
super().__init__()
# 确保所选的注意力模式在定义的模式中
assert attn_mode in self.ATTENTION_MODES
# 获取对应的注意力类
attn_cls = self.ATTENTION_MODES[attn_mode]
# 初始化注意力层
self.attn1 = attn_cls(
query_dim=dim, # 查询维度
heads=n_heads, # 头数
dim_head=d_head, # 每头的维度
dropout=dropout, # 丢弃率
context_dim=context_dim, # 上下文维度
)
# 初始化前馈网络
self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
# 初始化层归一化
self.norm1 = nn.LayerNorm(dim)
self.norm2 = nn.LayerNorm(dim)
# 设置检查点标志
self.checkpoint = checkpoint
# 前向传播方法
def forward(self, x, context=None):
# 使用检查点机制来进行前向传播
return checkpoint(
self._forward, (x, context), self.parameters(), self.checkpoint
)
# 实际的前向传播实现
def _forward(self, x, context=None):
# 通过注意力层进行处理并添加残差连接
x = self.attn1(self.norm1(x), context=context) + x
# 通过前馈网络处理并添加残差连接
x = self.ff(self.norm2(x)) + x
# 返回处理后的结果
return x
# 定义空间变换器类,继承自 nn.Module
class SpatialTransformer(nn.Module):
"""
适用于图像数据的变换器块。
首先,将输入(即嵌入)投影并重塑为 b, t, d 形状。
然后应用标准的变换器操作。
最后,重塑为图像。
新增:使用线性层提高效率,而不是 1x1 卷积。
"""
# 初始化方法,设置变换器块的基本参数
def __init__(
self,
in_channels, # 输入通道数
n_heads, # 注意力头数
d_head, # 每个头的维度
depth=1, # 变换器块的深度
dropout=0.0, # 丢弃率
context_dim=None, # 上下文维度(可选)
disable_self_attn=False, # 是否禁用自注意力
use_linear=False, # 是否使用线性层
attn_type="softmax", # 注意力类型
use_checkpoint=True, # 是否使用检查点
# sdp_backend=SDPBackend.FLASH_ATTENTION # 可选的 SDP 后端
sdp_backend=None, # SDP 后端设置(默认值为 None)
# 初始化类,调用父类构造函数
):
super().__init__()
# 打印构造信息(已注释)
# print(
# f"constructing {self.__class__.__name__} of depth {depth} w/ {in_channels} channels and {n_heads} heads"
# )
# 从 omegaconf 导入 ListConfig 类
from omegaconf import ListConfig
# 如果 context_dim 存在且不是列表或 ListConfig 类型
if exists(context_dim) and not isinstance(context_dim, (list, ListConfig)):
# 将 context_dim 转换为列表
context_dim = [context_dim]
# 如果 context_dim 存在且是列表
if exists(context_dim) and isinstance(context_dim, list):
# 检查 depth 是否与 context_dim 的长度匹配
if depth != len(context_dim):
# 打印警告信息(已注释)
# print(
# f"WARNING: {self.__class__.__name__}: Found context dims {context_dim} of depth {len(context_dim)}, "
# f"which does not match the specified 'depth' of {depth}. Setting context_dim to {depth * [context_dim[0]]} now."
# )
# 确保所有 context_dim 元素相同
assert all(
map(lambda x: x == context_dim[0], context_dim)
), "need homogenous context_dim to match depth automatically"
# 如果不一致,设置 context_dim 为相同值的列表
context_dim = depth * [context_dim[0]]
# 如果 context_dim 为 None
elif context_dim is None:
# 创建与 depth 长度相同的 None 列表
context_dim = [None] * depth
# 保存输入通道数
self.in_channels = in_channels
# 计算内部维度
inner_dim = n_heads * d_head
# 归一化层
self.norm = Normalize(in_channels)
# 如果不使用线性层
if not use_linear:
# 使用卷积层进行输入投影
self.proj_in = nn.Conv2d(
in_channels, inner_dim, kernel_size=1, stride=1, padding=0
)
else:
# 使用线性层进行输入投影
self.proj_in = nn.Linear(in_channels, inner_dim)
# 创建变压器模块列表
self.transformer_blocks = nn.ModuleList(
[
BasicTransformerBlock(
inner_dim,
n_heads,
d_head,
dropout=dropout,
context_dim=context_dim[d],
disable_self_attn=disable_self_attn,
attn_mode=attn_type,
checkpoint=use_checkpoint,
sdp_backend=sdp_backend,
)
# 遍历深度范围,生成多个变压器块
for d in range(depth)
]
)
# 如果不使用线性层
if not use_linear:
# 使用零初始化卷积层进行输出投影
self.proj_out = zero_module(
nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
)
else:
# 使用零初始化线性层进行输出投影(已注释)
# self.proj_out = zero_module(nn.Linear(in_channels, inner_dim))
self.proj_out = zero_module(nn.Linear(inner_dim, in_channels))
# 保存是否使用线性层的标志
self.use_linear = use_linear
# 定义前向传播函数,接收输入 x 和可选的上下文 context
def forward(self, x, context=None):
# 注意:如果没有提供上下文,交叉注意力默认为自注意力
if not isinstance(context, list):
# 将上下文包装为列表,方便后续处理
context = [context]
# 获取输入张量的形状:批量大小 b,通道数 c,高 h,宽 w
b, c, h, w = x.shape
# 保存输入张量的原始值以便后续使用
x_in = x
# 对输入进行归一化处理
x = self.norm(x)
# 如果不使用线性变换,则进行投影变换
if not self.use_linear:
x = self.proj_in(x)
# 重新排列张量的维度,将其从 (b, c, h, w) 变为 (b, h*w, c)
x = rearrange(x, "b c h w -> b (h w) c").contiguous()
# 如果使用线性变换,则再次进行投影变换
if self.use_linear:
x = self.proj_in(x)
# 遍历所有的变换块
for i, block in enumerate(self.transformer_blocks):
# 如果不是第一个块且上下文长度为1,则使用同一个上下文
if i > 0 and len(context) == 1:
i = 0 # 每个块使用相同的上下文
# 将输入传入当前变换块,并使用相应的上下文
x = block(x, context=context[i])
# 如果使用线性变换,则进行输出投影变换
if self.use_linear:
x = self.proj_out(x)
# 重新排列张量的维度,将其从 (b, h*w, c) 变为 (b, c, h, w)
x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w).contiguous()
# 如果不使用线性变换,则进行输出投影变换
if not self.use_linear:
x = self.proj_out(x)
# 返回处理后的张量与原始输入的和
return x + x_in
.\cogview3-finetune\sat\sgm\modules\autoencoding\losses\__init__.py
# 导入类型提示 Any 和 Union
from typing import Any, Union
# 导入 PyTorch 及其神经网络模块
import torch
import torch.nn as nn
# 从 einops 导入 rearrange 函数
from einops import rearrange
# 导入自定义工具函数和类
from ....util import default, instantiate_from_config
# 从 lpips 库导入 LPIPS 损失类
from ..lpips.loss.lpips import LPIPS
# 从 lpips 模型导入 NLayerDiscriminator 和权重初始化函数
from ..lpips.model.model import NLayerDiscriminator, weights_init
# 从 vqperceptual 模块导入两种损失函数
from ..lpips.vqperceptual import hinge_d_loss, vanilla_d_loss
# 定义 adopt_weight 函数,调整权重值
def adopt_weight(weight, global_step, threshold=0, value=0.0):
# 如果全局步数小于阈值,则将权重设为给定值
if global_step < threshold:
weight = value
# 返回调整后的权重
return weight
# 定义 LatentLPIPS 类,继承自 nn.Module
class LatentLPIPS(nn.Module):
# 初始化方法,设置相关参数
def __init__(
self,
decoder_config,
perceptual_weight=1.0,
latent_weight=1.0,
scale_input_to_tgt_size=False,
scale_tgt_to_input_size=False,
perceptual_weight_on_inputs=0.0,
):
# 调用父类构造方法
super().__init__()
# 设置输入大小缩放标志
self.scale_input_to_tgt_size = scale_input_to_tgt_size
self.scale_tgt_to_input_size = scale_tgt_to_input_size
# 初始化解码器
self.init_decoder(decoder_config)
# 初始化感知损失模型,并设置为评估模式
self.perceptual_loss = LPIPS().eval()
# 设置感知损失和潜在损失的权重
self.perceptual_weight = perceptual_weight
self.latent_weight = latent_weight
# 设置对输入的感知权重
self.perceptual_weight_on_inputs = perceptual_weight_on_inputs
# 定义初始化解码器的方法
def init_decoder(self, config):
# 从配置实例化解码器
self.decoder = instantiate_from_config(config)
# 如果解码器有 encoder 属性,则删除该属性
if hasattr(self.decoder, "encoder"):
del self.decoder.encoder
# 定义前向传播函数,接收潜在输入、潜在预测、图像输入和数据集切分信息
def forward(self, latent_inputs, latent_predictions, image_inputs, split="train"):
# 初始化一个字典用于记录日志信息
log = dict()
# 计算潜在输入与潜在预测之间的均方差损失
loss = (latent_inputs - latent_predictions) ** 2
# 将均方差损失的平均值添加到日志中,使用切分名称作为键
log[f"{split}/latent_l2_loss"] = loss.mean().detach()
# 初始化图像重建变量
image_reconstructions = None
# 如果感知损失权重大于0,则进行感知损失的计算
if self.perceptual_weight > 0.0:
# 解码潜在预测生成图像重建
image_reconstructions = self.decoder.decode(latent_predictions)
# 解码潜在输入生成目标图像
image_targets = self.decoder.decode(latent_inputs)
# 计算感知损失
perceptual_loss = self.perceptual_loss(
image_targets.contiguous(), image_reconstructions.contiguous()
)
# 综合潜在损失和感知损失,更新总损失
loss = (
self.latent_weight * loss.mean()
+ self.perceptual_weight * perceptual_loss.mean()
)
# 将感知损失的平均值添加到日志中
log[f"{split}/perceptual_loss"] = perceptual_loss.mean().detach()
# 如果感知权重在输入上大于0,则进行相应的处理
if self.perceptual_weight_on_inputs > 0.0:
# 如果重建图像为空,则解码潜在预测生成图像重建
image_reconstructions = default(
image_reconstructions, self.decoder.decode(latent_predictions)
)
# 如果需要将输入图像缩放到目标图像大小
if self.scale_input_to_tgt_size:
image_inputs = torch.nn.functional.interpolate(
image_inputs,
image_reconstructions.shape[2:],
mode="bicubic", # 使用双三次插值法
antialias=True, # 使用抗锯齿
)
# 如果需要将目标图像缩放到输入图像大小
elif self.scale_tgt_to_input_size:
image_reconstructions = torch.nn.functional.interpolate(
image_reconstructions,
image_inputs.shape[2:],
mode="bicubic", # 使用双三次插值法
antialias=True, # 使用抗锯齿
)
# 计算与输入图像的感知损失
perceptual_loss2 = self.perceptual_loss(
image_inputs.contiguous(), image_reconstructions.contiguous()
)
# 更新总损失,加入输入的感知损失
loss = loss + self.perceptual_weight_on_inputs * perceptual_loss2.mean()
# 将输入的感知损失的平均值添加到日志中
log[f"{split}/perceptual_loss_on_inputs"] = perceptual_loss2.mean().detach()
# 返回总损失和日志信息
return loss, log
# 定义一个带有判别器的通用 LPIPS 类,继承自 nn.Module
class GeneralLPIPSWithDiscriminator(nn.Module):
# 初始化方法,接收多个参数以配置模型
def __init__(
self,
disc_start: int, # 判别器开始训练的迭代次数
logvar_init: float = 0.0, # 日志方差的初始值
pixelloss_weight=1.0, # 像素损失的权重
disc_num_layers: int = 3, # 判别器的层数
disc_in_channels: int = 3, # 判别器输入的通道数
disc_factor: float = 1.0, # 判别器的缩放因子
disc_weight: float = 1.0, # 判别器损失的权重
perceptual_weight: float = 1.0, # 感知损失的权重
disc_loss: str = "hinge", # 判别器使用的损失类型
scale_input_to_tgt_size: bool = False, # 是否将输入缩放到目标大小
dims: int = 2, # 数据的维度
learn_logvar: bool = False, # 是否学习日志方差
regularization_weights: Union[None, dict] = None, # 正则化权重
):
# 调用父类的初始化方法
super().__init__()
self.dims = dims # 保存维度信息
# 如果维度大于2,打印警告信息
if self.dims > 2:
print(
f"running with dims={dims}. This means that for perceptual loss calculation, "
f"the LPIPS loss will be applied to each frame independently. "
)
self.scale_input_to_tgt_size = scale_input_to_tgt_size # 保存输入缩放标志
# 确保判别器损失类型为 hinge 或 vanilla
assert disc_loss in ["hinge", "vanilla"]
self.pixel_weight = pixelloss_weight # 保存像素损失权重
self.perceptual_loss = LPIPS().eval() # 初始化 LPIPS 感知损失并设置为评估模式
self.perceptual_weight = perceptual_weight # 保存感知损失权重
# 输出日志方差,作为可学习的参数
self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init)
self.learn_logvar = learn_logvar # 保存是否学习日志方差的标志
# 初始化 NLayerDiscriminator 作为判别器
self.discriminator = NLayerDiscriminator(
input_nc=disc_in_channels, n_layers=disc_num_layers, use_actnorm=False
).apply(weights_init) # 应用权重初始化
self.discriminator_iter_start = disc_start # 保存判别器开始训练的迭代次数
# 根据损失类型选择合适的损失函数
self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss
self.disc_factor = disc_factor # 保存判别器缩放因子
self.discriminator_weight = disc_weight # 保存判别器损失权重
self.regularization_weights = default(regularization_weights, {}) # 设置正则化权重,默认为空字典
# 获取可训练的参数
def get_trainable_parameters(self) -> Any:
return self.discriminator.parameters() # 返回判别器的参数
# 获取可训练的自编码器参数
def get_trainable_autoencoder_parameters(self) -> Any:
if self.learn_logvar: # 如果学习日志方差
yield self.logvar # 生成日志方差
yield from () # 返回空生成器
# 计算自适应权重
def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
if last_layer is not None: # 如果提供了最后一层
# 计算负对数似然损失的梯度
nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
# 计算生成器损失的梯度
g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
else:
# 如果没有提供最后一层,使用类属性中的最后一层
nll_grads = torch.autograd.grad(
nll_loss, self.last_layer[0], retain_graph=True
)[0]
g_grads = torch.autograd.grad(
g_loss, self.last_layer[0], retain_graph=True
)[0]
# 计算判别器权重
d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
d_weight = torch.clamp(d_weight, 0.0, 1e4).detach() # 限制权重范围并分离计算图
d_weight = d_weight * self.discriminator_weight # 应用判别器权重
return d_weight # 返回计算得到的权重
# 前向传播方法
def forward(
self,
regularization_log, # 正则化日志
inputs, # 输入数据
reconstructions, # 重建数据
optimizer_idx, # 优化器索引
global_step, # 全局步数
last_layer=None, # 最后一层(可选)
split="train", # 数据集划分(训练或验证)
weights=None, # 权重(可选)
.\cogview3-finetune\sat\sgm\modules\autoencoding\lpips\loss\lpips.py
# 从 https://github.com/richzhang/PerceptualSimilarity/tree/master/models 中剥离的版本
"""Stripped version of https://github.com/richzhang/PerceptualSimilarity/tree/master/models"""
# 从 collections 模块导入 namedtuple
from collections import namedtuple
# 导入 PyTorch 和神经网络模块
import torch
import torch.nn as nn
# 从 torchvision 导入模型
from torchvision import models
# 从上层目录的 util 模块导入获取检查点路径的函数
from ..util import get_ckpt_path
# 定义 LPIPS 类,继承自 nn.Module
class LPIPS(nn.Module):
# 学习的感知度量
def __init__(self, use_dropout=True):
# 调用父类构造函数
super().__init__()
# 初始化缩放层
self.scaling_layer = ScalingLayer()
# 定义特征通道数量,针对 VGG16 特征
self.chns = [64, 128, 256, 512, 512] # vg16 features
# 加载预训练的 VGG16 模型,不计算其梯度
self.net = vgg16(pretrained=True, requires_grad=False)
# 为每个通道初始化线性层
self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)
self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)
self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)
self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
# 从预训练模型中加载权重
self.load_from_pretrained()
# 禁用所有参数的梯度更新
for param in self.parameters():
param.requires_grad = False
# 从预训练模型加载权重的方法
def load_from_pretrained(self, name="vgg_lpips"):
# 获取检查点文件的路径
ckpt = get_ckpt_path(name, "sgm/modules/autoencoding/lpips/loss")
# 加载权重到当前模型中,严格匹配状态字典
self.load_state_dict(
torch.load(ckpt, map_location=torch.device("cpu")), strict=False
)
# 打印加载的模型路径
print("loaded pretrained LPIPS loss from {}".format(ckpt))
# 类方法,用于从预训练模型创建实例
@classmethod
def from_pretrained(cls, name="vgg_lpips"):
# 检查模型名称是否有效
if name != "vgg_lpips":
raise NotImplementedError
# 创建当前类的实例
model = cls()
# 获取检查点路径
ckpt = get_ckpt_path(name)
# 加载权重到模型中
model.load_state_dict(
torch.load(ckpt, map_location=torch.device("cpu")), strict=False
)
# 返回模型实例
return model
# 前向传播方法
def forward(self, input, target):
# 对输入和目标进行缩放处理
in0_input, in1_input = (self.scaling_layer(input), self.scaling_layer(target))
# 通过网络提取特征
outs0, outs1 = self.net(in0_input), self.net(in1_input)
# 初始化特征和差异字典
feats0, feats1, diffs = {}, {}, {}
# 收集线性层
lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4]
# 遍历特征通道
for kk in range(len(self.chns)):
# 标准化特征并计算差异
feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(
outs1[kk]
)
diffs[kk] = (feats0[kk] - feats1[kk]) ** 2
# 计算每个通道的结果
res = [
spatial_average(lins[kk].model(diffs[kk]), keepdim=True)
for kk in range(len(self.chns))
]
# 初始化最终值
val = res[0]
# 累加结果
for l in range(1, len(self.chns)):
val += res[l]
# 返回最终的值
return val
# 定义缩放层类,继承自 nn.Module
class ScalingLayer(nn.Module):
# 构造函数
def __init__(self):
# 调用父类构造函数
super(ScalingLayer, self).__init__()
# 注册偏移量的缓冲区
self.register_buffer(
"shift", torch.Tensor([-0.030, -0.088, -0.188])[None, :, None, None]
)
# 注册缩放因子的缓冲区
self.register_buffer(
"scale", torch.Tensor([0.458, 0.448, 0.450])[None, :, None, None]
)
# 前向传播方法
def forward(self, inp):
# 进行缩放操作
return (inp - self.shift) / self.scale
# 定义线性层类,继承自 nn.Module
class NetLinLayer(nn.Module):
"""一个单独的线性层,执行 1x1 卷积"""
# 初始化网络层,接收输入通道数、输出通道数和是否使用 dropout
def __init__(self, chn_in, chn_out=1, use_dropout=False):
# 调用父类的初始化方法
super(NetLinLayer, self).__init__()
# 根据是否使用 dropout 创建层列表
layers = (
[
nn.Dropout(), # 添加 dropout 层
]
if (use_dropout) # 如果使用 dropout
else [] # 否则为空列表
)
# 添加卷积层到层列表,输入通道为 chn_in,输出通道为 chn_out,卷积核大小为 1
layers += [
nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False),
]
# 将层列表包装成一个顺序容器,便于按顺序调用各层
self.model = nn.Sequential(*layers)
# 定义一个 VGG16 类,继承自 PyTorch 的 Module 类
class vgg16(torch.nn.Module):
# 初始化函数,接受是否需要梯度和是否使用预训练模型的参数
def __init__(self, requires_grad=False, pretrained=True):
# 调用父类的初始化函数
super(vgg16, self).__init__()
# 获取预训练的 VGG16 特征提取层
vgg_pretrained_features = models.vgg16(pretrained=pretrained).features
# 定义五个序列层
self.slice1 = torch.nn.Sequential()
self.slice2 = torch.nn.Sequential()
self.slice3 = torch.nn.Sequential()
self.slice4 = torch.nn.Sequential()
self.slice5 = torch.nn.Sequential()
# 设置切片的数量为 5
self.N_slices = 5
# 将前4层添加到 slice1 中
for x in range(4):
self.slice1.add_module(str(x), vgg_pretrained_features[x])
# 将第 4 到 8 层添加到 slice2 中
for x in range(4, 9):
self.slice2.add_module(str(x), vgg_pretrained_features[x])
# 将第 9 到 15 层添加到 slice3 中
for x in range(9, 16):
self.slice3.add_module(str(x), vgg_pretrained_features[x])
# 将第 16 到 22 层添加到 slice4 中
for x in range(16, 23):
self.slice4.add_module(str(x), vgg_pretrained_features[x])
# 将第 23 到 29 层添加到 slice5 中
for x in range(23, 30):
self.slice5.add_module(str(x), vgg_pretrained_features[x])
# 如果不需要梯度,则冻结所有参数
if not requires_grad:
for param in self.parameters():
param.requires_grad = False
# 定义前向传播函数
def forward(self, X):
# 将输入通过 slice1 层
h = self.slice1(X)
# 保存第一层的输出
h_relu1_2 = h
# 将输出通过 slice2 层
h = self.slice2(h)
# 保存第二层的输出
h_relu2_2 = h
# 将输出通过 slice3 层
h = self.slice3(h)
# 保存第三层的输出
h_relu3_3 = h
# 将输出通过 slice4 层
h = self.slice4(h)
# 保存第四层的输出
h_relu4_3 = h
# 将输出通过 slice5 层
h = self.slice5(h)
# 保存第五层的输出
h_relu5_3 = h
# 创建一个命名元组来存储不同层的输出
vgg_outputs = namedtuple(
"VggOutputs", ["relu1_2", "relu2_2", "relu3_3", "relu4_3", "relu5_3"]
)
# 将各层输出组合成一个元组
out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)
# 返回元组
return out
# 定义一个函数,用于规范化张量
def normalize_tensor(x, eps=1e-10):
# 计算张量的 L2 范数
norm_factor = torch.sqrt(torch.sum(x**2, dim=1, keepdim=True))
# 返回规范化后的张量,避免除以零
return x / (norm_factor + eps)
# 定义一个空间平均函数
def spatial_average(x, keepdim=True):
# 在高和宽维度上计算平均值
return x.mean([2, 3], keepdim=keepdim)
.\cogview3-finetune\sat\sgm\modules\autoencoding\lpips\loss\__init__.py
.\cogview3-finetune\sat\sgm\modules\autoencoding\lpips\model\model.py
# 导入 functools 模块,用于函数工具
import functools
# 导入 nn 模块,用于构建神经网络
import torch.nn as nn
# 从上级目录的 util 模块导入 ActNorm
from ..util import ActNorm
# 定义权重初始化函数
def weights_init(m):
# 获取模块的类名
classname = m.__class__.__name__
# 如果类名包含 "Conv",则初始化卷积层的权重
if classname.find("Conv") != -1:
nn.init.normal_(m.weight.data, 0.0, 0.02)
# 如果类名包含 "BatchNorm",则初始化批归一化层的权重和偏置
elif classname.find("BatchNorm") != -1:
nn.init.normal_(m.weight.data, 1.0, 0.02)
nn.init.constant_(m.bias.data, 0)
# 定义 NLayerDiscriminator 类,继承自 nn.Module
class NLayerDiscriminator(nn.Module):
"""定义一个 PatchGAN 判别器,如 Pix2Pix 所示
--> 参见 https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py
"""
# 构造函数,初始化参数
def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False):
"""构造一个 PatchGAN 判别器
参数:
input_nc (int) -- 输入图像的通道数
ndf (int) -- 最后一个卷积层中的滤波器数量
n_layers (int) -- 判别器中的卷积层数量
norm_layer -- 归一化层
"""
# 调用父类的构造函数
super(NLayerDiscriminator, self).__init__()
# 根据是否使用 ActNorm 来选择归一化层
if not use_actnorm:
norm_layer = nn.BatchNorm2d
else:
norm_layer = ActNorm
# 如果归一化层是 functools.partial 类型,判断是否使用偏置
if (
type(norm_layer) == functools.partial
): # 不需要使用偏置,因为 BatchNorm2d 有仿射参数
use_bias = norm_layer.func != nn.BatchNorm2d
else:
use_bias = norm_layer != nn.BatchNorm2d
kw = 4 # 卷积核的大小
padw = 1 # 卷积层的填充
# 初始化序列,添加第一个卷积层和 LeakyReLU 激活函数
sequence = [
nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw),
nn.LeakyReLU(0.2, True),
]
nf_mult = 1 # 当前滤波器倍数
nf_mult_prev = 1 # 上一层的滤波器倍数
# 逐渐增加滤波器数量,构建后续的卷积层
for n in range(1, n_layers): # 逐渐增加滤波器数量
nf_mult_prev = nf_mult
nf_mult = min(2**n, 8) # 滤波器倍数最大为 8
sequence += [
nn.Conv2d(
ndf * nf_mult_prev, # 输入通道数
ndf * nf_mult, # 输出通道数
kernel_size=kw, # 卷积核大小
stride=2, # 步幅
padding=padw, # 填充
bias=use_bias, # 是否使用偏置
),
norm_layer(ndf * nf_mult), # 添加归一化层
nn.LeakyReLU(0.2, True), # 添加 LeakyReLU 激活函数
]
nf_mult_prev = nf_mult # 更新上一个滤波器倍数
nf_mult = min(2**n_layers, 8) # 计算最终滤波器倍数
sequence += [
nn.Conv2d(
ndf * nf_mult_prev, # 输入通道数
ndf * nf_mult, # 输出通道数
kernel_size=kw, # 卷积核大小
stride=1, # 步幅
padding=padw, # 填充
bias=use_bias, # 是否使用偏置
),
norm_layer(ndf * nf_mult), # 添加归一化层
nn.LeakyReLU(0.2, True), # 添加 LeakyReLU 激活函数
]
sequence += [
nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)
] # 输出 1 通道的预测图
# 将所有层组合成一个序列
self.main = nn.Sequential(*sequence)
# 定义前向传播函数
def forward(self, input):
"""标准前向传播。"""
return self.main(input) # 将输入传入序列并返回输出
.\cogview3-finetune\sat\sgm\modules\autoencoding\lpips\model\__init__.py
.\cogview3-finetune\sat\sgm\modules\autoencoding\lpips\util.py
# 导入所需的库
import hashlib # 导入 hashlib 库用于计算文件的 MD5 哈希
import os # 导入 os 库用于文件和目录操作
import requests # 导入 requests 库用于发送 HTTP 请求
import torch # 导入 PyTorch 库用于深度学习
import torch.nn as nn # 导入 nn 模块用于构建神经网络
from tqdm import tqdm # 导入 tqdm 库用于显示进度条
# 定义模型 URL 映射字典
URL_MAP = {"vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1"}
# 定义模型检查点文件名映射字典
CKPT_MAP = {"vgg_lpips": "vgg.pth"}
# 定义模型 MD5 哈希值映射字典
MD5_MAP = {"vgg_lpips": "d507d7349b931f0638a25a48a722f98a"}
# 定义下载函数
def download(url, local_path, chunk_size=1024):
# 创建本地路径的父目录,如果不存在则创建
os.makedirs(os.path.split(local_path)[0], exist_ok=True)
# 发送 GET 请求以流式下载文件
with requests.get(url, stream=True) as r:
# 获取响应头中的内容长度
total_size = int(r.headers.get("content-length", 0))
# 使用 tqdm 显示下载进度条
with tqdm(total=total_size, unit="B", unit_scale=True) as pbar:
# 以二进制写入模式打开本地文件
with open(local_path, "wb") as f:
# 分块读取响应内容
for data in r.iter_content(chunk_size=chunk_size):
# 如果读取到数据,则写入文件
if data:
f.write(data) # 写入数据到文件
pbar.update(chunk_size) # 更新进度条
# 定义 MD5 哈希函数
def md5_hash(path):
# 以二进制读取模式打开指定路径的文件
with open(path, "rb") as f:
content = f.read() # 读取文件内容
# 返回文件内容的 MD5 哈希值
return hashlib.md5(content).hexdigest()
# 定义获取检查点路径的函数
def get_ckpt_path(name, root, check=False):
# 确保给定的模型名称在 URL 映射中
assert name in URL_MAP
# 组合根目录和检查点文件名,形成完整路径
path = os.path.join(root, CKPT_MAP[name])
# 检查文件是否存在或是否需要重新下载
if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]):
# 打印下载信息
print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path))
download(URL_MAP[name], path) # 下载文件
md5 = md5_hash(path) # 计算下载文件的 MD5 哈希值
# 确保下载的文件 MD5 值与预期匹配
assert md5 == MD5_MAP[name], md5
return path # 返回检查点路径
# 定义一个自定义的神经网络模块
class ActNorm(nn.Module):
# 构造函数,初始化参数
def __init__(
self, num_features, logdet=False, affine=True, allow_reverse_init=False
):
assert affine # 确保启用仿射变换
super().__init__() # 调用父类构造函数
self.logdet = logdet # 保存 logdet 标志
# 定义可学习的均值参数
self.loc = nn.Parameter(torch.zeros(1, num_features, 1, 1))
# 定义可学习的缩放参数
self.scale = nn.Parameter(torch.ones(1, num_features, 1, 1))
self.allow_reverse_init = allow_reverse_init # 保存是否允许反向初始化标志
# 注册一个缓冲区,用于记录初始化状态
self.register_buffer("initialized", torch.tensor(0, dtype=torch.uint8))
# 定义初始化函数
def initialize(self, input):
with torch.no_grad(): # 在不计算梯度的上下文中
# 将输入张量重排列并展平
flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1)
# 计算展平后的均值
mean = (
flatten.mean(1)
.unsqueeze(1)
.unsqueeze(2)
.unsqueeze(3)
.permute(1, 0, 2, 3)
)
# 计算展平后的标准差
std = (
flatten.std(1)
.unsqueeze(1)
.unsqueeze(2)
.unsqueeze(3)
.permute(1, 0, 2, 3)
)
# 将均值复制到 loc 参数
self.loc.data.copy_(-mean)
# 将标准差的倒数复制到 scale 参数
self.scale.data.copy_(1 / (std + 1e-6))
# 定义前向传播函数,接受输入和反向标志
def forward(self, input, reverse=False):
# 如果反向标志为真,调用反向函数处理输入
if reverse:
return self.reverse(input)
# 检查输入的形状是否为二维
if len(input.shape) == 2:
# 将二维输入扩展为四维,增加两个新的维度
input = input[:, :, None, None]
squeeze = True # 标记为需要压缩
else:
squeeze = False # 不需要压缩
# 解包输入的高度和宽度
_, _, height, width = input.shape
# 如果处于训练状态且尚未初始化
if self.training and self.initialized.item() == 0:
# 初始化参数
self.initialize(input)
# 标记为已初始化
self.initialized.fill_(1)
# 根据比例因子和位移量调整输入
h = self.scale * (input + self.loc)
# 如果需要压缩,移除最后两个维度
if squeeze:
h = h.squeeze(-1).squeeze(-1)
# 如果需要计算对数行列式
if self.logdet:
# 计算比例因子的绝对值的对数
log_abs = torch.log(torch.abs(self.scale))
# 计算对数行列式的值
logdet = height * width * torch.sum(log_abs)
# 生成与批量大小相同的对数行列式张量
logdet = logdet * torch.ones(input.shape[0]).to(input)
# 返回调整后的输出和对数行列式
return h, logdet
# 返回调整后的输出
return h
# 定义反向传播函数,接受输出
def reverse(self, output):
# 如果处于训练状态且尚未初始化
if self.training and self.initialized.item() == 0:
# 如果不允许在反向方向初始化,则抛出错误
if not self.allow_reverse_init:
raise RuntimeError(
"Initializing ActNorm in reverse direction is "
"disabled by default. Use allow_reverse_init=True to enable."
)
else:
# 初始化参数
self.initialize(output)
# 标记为已初始化
self.initialized.fill_(1)
# 检查输出的形状是否为二维
if len(output.shape) == 2:
# 将二维输出扩展为四维,增加两个新的维度
output = output[:, :, None, None]
squeeze = True # 标记为需要压缩
else:
squeeze = False # 不需要压缩
# 根据比例因子和位移量调整输出
h = output / self.scale - self.loc
# 如果需要压缩,移除最后两个维度
if squeeze:
h = h.squeeze(-1).squeeze(-1)
# 返回调整后的输出
return h
.\cogview3-finetune\sat\sgm\modules\autoencoding\lpips\vqperceptual.py
# 导入 PyTorch 库
import torch
# 导入 PyTorch 中的功能性模块
import torch.nn.functional as F
# 定义对抗训练中的判别器损失函数(hinge 损失)
def hinge_d_loss(logits_real, logits_fake):
# 计算真实样本的损失,使用 ReLU 激活函数
loss_real = torch.mean(F.relu(1.0 - logits_real))
# 计算假样本的损失,使用 ReLU 激活函数
loss_fake = torch.mean(F.relu(1.0 + logits_fake))
# 计算总的判别器损失,取真实和假样本损失的平均值
d_loss = 0.5 * (loss_real + loss_fake)
# 返回判别器损失
return d_loss
# 定义对抗训练中的另一种判别器损失函数(vanilla 损失)
def vanilla_d_loss(logits_real, logits_fake):
# 计算总的判别器损失,使用 softplus 函数
d_loss = 0.5 * (
torch.mean(torch.nn.functional.softplus(-logits_real)) # 计算真实样本的 softplus 损失
+ torch.mean(torch.nn.functional.softplus(logits_fake)) # 计算假样本的 softplus 损失
)
# 返回判别器损失
return d_loss
.\cogview3-finetune\sat\sgm\modules\autoencoding\lpips\__init__.py
请提供需要添加注释的代码片段。
.\cogview3-finetune\sat\sgm\modules\autoencoding\regularizers\__init__.py
# 导入抽象方法的库
from abc import abstractmethod
# 导入任意类型和元组类型
from typing import Any, Tuple
# 导入 PyTorch 库
import torch
# 导入神经网络模块
import torch.nn as nn
# 导入功能模块
import torch.nn.functional as F
# 导入对角高斯分布类
from ....modules.distributions.distributions import DiagonalGaussianDistribution
# 定义抽象正则化器类,继承自 nn.Module
class AbstractRegularizer(nn.Module):
# 初始化方法
def __init__(self):
super().__init__() # 调用父类构造方法
# 前向传播方法,接受一个张量,返回一个张量和字典
def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]:
raise NotImplementedError() # 抛出未实现错误
# 抽象方法,获取可训练的参数
@abstractmethod
def get_trainable_parameters(self) -> Any:
raise NotImplementedError() # 抛出未实现错误
# 定义对角高斯正则化器类,继承自 AbstractRegularizer
class DiagonalGaussianRegularizer(AbstractRegularizer):
# 初始化方法,接受一个布尔参数,默认值为 True
def __init__(self, sample: bool = True):
super().__init__() # 调用父类构造方法
self.sample = sample # 存储是否采样的参数
# 获取可训练参数的方法,返回一个生成器
def get_trainable_parameters(self) -> Any:
yield from () # 生成器为空,表示没有可训练参数
# 前向传播方法,接受一个张量,返回一个张量和字典
def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]:
log = dict() # 创建一个空字典,用于记录日志
posterior = DiagonalGaussianDistribution(z) # 创建对角高斯分布实例
if self.sample: # 如果需要采样
z = posterior.sample() # 从后验分布中采样
else: # 如果不需要采样
z = posterior.mode() # 取后验分布的众数
kl_loss = posterior.kl() # 计算 Kullback-Leibler 散度损失
kl_loss = torch.sum(kl_loss) / kl_loss.shape[0] # 计算均值损失
log["kl_loss"] = kl_loss # 将 KL 损失记录到日志中
return z, log # 返回采样或众数以及日志
# 定义测量困惑度的函数,接受预测的索引和质心数量
def measure_perplexity(predicted_indices, num_centroids):
# src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py
# 评估聚类的困惑度。当困惑度等于 num_embeddings 时,所有聚类被完全均匀使用
encodings = (
F.one_hot(predicted_indices, num_centroids).float().reshape(-1, num_centroids) # 将预测索引转换为独热编码
)
avg_probs = encodings.mean(0) # 计算每个质心的平均概率
perplexity = (-(avg_probs * torch.log(avg_probs + 1e-10)).sum()).exp() # 计算困惑度
cluster_use = torch.sum(avg_probs > 0) # 计算被使用的聚类数量
return perplexity, cluster_use # 返回困惑度和聚类使用情况
.\cogview3-finetune\sat\sgm\modules\autoencoding\__init__.py
请提供需要注释的代码,我将帮助你添加详细的注释。
.\cogview3-finetune\sat\sgm\modules\diffusionmodules\denoiser.py
# 导入所需的类型定义,便于后续使用
from typing import Dict, Union
# 导入 PyTorch 及其神经网络模块
import torch
import torch.nn as nn
# 从上层目录导入实用工具函数
from ...util import append_dims, instantiate_from_config
# 定义去噪器类,继承自 nn.Module
class Denoiser(nn.Module):
# 初始化方法,接受加权配置和缩放配置
def __init__(self, weighting_config, scaling_config):
# 调用父类构造函数
super().__init__()
# 根据加权配置实例化加权模块
self.weighting = instantiate_from_config(weighting_config)
# 根据缩放配置实例化缩放模块
self.scaling = instantiate_from_config(scaling_config)
# 可能对 sigma 进行量化的占位符方法
def possibly_quantize_sigma(self, sigma):
return sigma # 返回未修改的 sigma
# 可能对噪声 c_noise 进行量化的占位符方法
def possibly_quantize_c_noise(self, c_noise):
return c_noise # 返回未修改的 c_noise
# 计算加权后的 sigma 值
def w(self, sigma):
return self.weighting(sigma) # 返回加权后的 sigma
# 前向传播方法,定义模型如何处理输入
def forward(
self,
network: nn.Module, # 网络模型
input: torch.Tensor, # 输入张量
sigma: torch.Tensor, # sigma 张量
cond: Dict, # 条件字典
**additional_model_inputs, # 其他模型输入
) -> torch.Tensor:
# 可能对 sigma 进行量化
sigma = self.possibly_quantize_sigma(sigma)
# 获取 sigma 的形状
sigma_shape = sigma.shape
# 调整 sigma 的维度,以匹配输入的维度
sigma = append_dims(sigma, input.ndim)
# 使用缩放模块处理 sigma 和额外模型输入,获取多个输出
c_skip, c_out, c_in, c_noise = self.scaling(sigma, **additional_model_inputs)
# 可能对噪声 c_noise 进行量化,并调整形状
c_noise = self.possibly_quantize_c_noise(c_noise.reshape(sigma_shape))
# 返回经过网络处理的结果,结合输入和噪声
return (
network(input * c_in, c_noise, cond, **additional_model_inputs) * c_out
+ input * c_skip
)
# 定义离散去噪器类,继承自 Denoiser
class DiscreteDenoiser(Denoiser):
# 初始化方法,接受多个配置参数
def __init__(
self,
weighting_config, # 加权配置
scaling_config, # 缩放配置
num_idx, # 索引数量
discretization_config, # 离散化配置
do_append_zero=False, # 是否附加零
quantize_c_noise=True, # 是否量化 c_noise
flip=True, # 是否翻转
):
# 调用父类构造函数
super().__init__(weighting_config, scaling_config)
# 根据离散化配置实例化 sigma
sigmas = instantiate_from_config(discretization_config)(
num_idx, do_append_zero=do_append_zero, flip=flip
)
# 保存 sigma 值
self.sigmas = sigmas
# self.register_buffer("sigmas", sigmas) # (可选)注册一个持久化的缓冲区
# 设置是否量化 c_noise
self.quantize_c_noise = quantize_c_noise
# 将 sigma 转换为索引的方法
def sigma_to_idx(self, sigma):
# 计算 sigma 与 sigma 列表的距离
dists = sigma - self.sigmas.to(sigma.device)[:, None]
# 返回距离最小的索引
return dists.abs().argmin(dim=0).view(sigma.shape)
# 将索引转换为 sigma 的方法
def idx_to_sigma(self, idx):
return self.sigmas.to(idx.device)[idx] # 根据索引返回对应的 sigma
# 可能对 sigma 进行量化的重写方法
def possibly_quantize_sigma(self, sigma):
return self.idx_to_sigma(self.sigma_to_idx(sigma)) # 返回量化后的 sigma
# 可能对噪声 c_noise 进行量化的重写方法
def possibly_quantize_c_noise(self, c_noise):
if self.quantize_c_noise: # 如果选择量化 c_noise
return self.sigma_to_idx(c_noise) # 返回量化后的索引
else:
return c_noise # 返回未修改的 c_noise
.\cogview3-finetune\sat\sgm\modules\diffusionmodules\denoiser_scaling.py
# 从 abc 模块导入抽象基类和抽象方法
from abc import ABC, abstractmethod
# 从 typing 模块导入任意类型和元组
from typing import Any, Tuple
# 导入 PyTorch 库
import torch
# 定义一个抽象基类 DenoiserScaling
class DenoiserScaling(ABC):
# 定义一个抽象方法 __call__
@abstractmethod
def __call__(
self, sigma: torch.Tensor, **additional_model_inputs
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
# 抽象方法没有具体实现
pass
# 定义 EDMScaling 类
class EDMScaling:
# 初始化方法,接受一个 sigma 数据,默认值为 0.5
def __init__(self, sigma_data: float = 0.5):
# 将 sigma_data 保存为实例变量
self.sigma_data = sigma_data
# 定义一个可调用方法 __call__
def __call__(
self, sigma: torch.Tensor, **additional_model_inputs
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
# 计算 c_skip 的值
c_skip = self.sigma_data**2 / (sigma**2 + self.sigma_data**2)
# 计算 c_out 的值
c_out = sigma * self.sigma_data / (sigma**2 + self.sigma_data**2) ** 0.5
# 计算 c_in 的值
c_in = 1 / (sigma**2 + self.sigma_data**2) ** 0.5
# 计算 c_noise 的值
c_noise = 0.25 * sigma.log()
# 返回四个计算结果
return c_skip, c_out, c_in, c_noise
# 定义 EpsScaling 类
class EpsScaling:
# 定义一个可调用方法 __call__
def __call__(
self, sigma: torch.Tensor, **additional_model_inputs
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
# 创建与 sigma 相同形状的全 1 张量作为 c_skip
c_skip = torch.ones_like(sigma, device=sigma.device)
# 计算 c_out 的值为 -sigma
c_out = -sigma
# 计算 c_in 的值
c_in = 1 / (sigma**2 + 1.0) ** 0.5
# 复制 sigma 作为 c_noise
c_noise = sigma.clone()
# 返回四个计算结果
return c_skip, c_out, c_in, c_noise
# 定义 VScaling 类
class VScaling:
# 定义一个可调用方法 __call__
def __call__(
self, sigma: torch.Tensor, **additional_model_inputs
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
# 计算 c_skip 的值
c_skip = 1.0 / (sigma**2 + 1.0)
# 计算 c_out 的值
c_out = -sigma / (sigma**2 + 1.0) ** 0.5
# 计算 c_in 的值
c_in = 1.0 / (sigma**2 + 1.0) ** 0.5
# 复制 sigma 作为 c_noise
c_noise = sigma.clone()
# 返回四个计算结果
return c_skip, c_out, c_in, c_noise
# 定义 VScalingWithEDMcNoise 类,继承自 DenoiserScaling
class VScalingWithEDMcNoise(DenoiserScaling):
# 定义一个可调用方法 __call__
def __call__(
self, sigma: torch.Tensor, **additional_model_inputs
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
# 计算 c_skip 的值
c_skip = 1.0 / (sigma**2 + 1.0)
# 计算 c_out 的值
c_out = -sigma / (sigma**2 + 1.0) ** 0.5
# 计算 c_in 的值
c_in = 1.0 / (sigma**2 + 1.0) ** 0.5
# 计算 c_noise 的值
c_noise = 0.25 * sigma.log()
# 返回四个计算结果
return c_skip, c_out, c_in, c_noise
# 定义 ZeroSNRScaling 类,类似于 VScaling
class ZeroSNRScaling: # similar to VScaling
# 定义一个可调用方法 __call__
def __call__(
self, alphas_cumprod_sqrt: torch.Tensor, **additional_model_inputs
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
# 将 alphas_cumprod_sqrt 作为 c_skip
c_skip = alphas_cumprod_sqrt
# 计算 c_out 的值
c_out = - (1 - alphas_cumprod_sqrt**2) ** 0.5
# 创建与 alphas_cumprod_sqrt 相同形状的全 1 张量作为 c_in
c_in = torch.ones_like(alphas_cumprod_sqrt, device=alphas_cumprod_sqrt.device)
# 复制额外输入的 'idx' 作为 c_noise
c_noise = additional_model_inputs['idx'].clone()
# 返回四个计算结果
return c_skip, c_out, c_in, c_noise
.\cogview3-finetune\sat\sgm\modules\diffusionmodules\denoiser_weighting.py
# 导入 PyTorch 库
import torch
# 定义一个单位加权类
class UnitWeighting:
# 定义可调用方法,接收参数 sigma
def __call__(self, sigma):
# 返回与 sigma 形状相同的全 1 张量,设备与 sigma 相同
return torch.ones_like(sigma, device=sigma.device)
# 定义 EDM 加权类
class EDMWeighting:
# 初始化方法,设置 sigma_data 的默认值为 0.5
def __init__(self, sigma_data=0.5):
# 将传入的 sigma_data 存储为实例变量
self.sigma_data = sigma_data
# 定义可调用方法,接收参数 sigma
def __call__(self, sigma):
# 根据公式计算加权值并返回
return (sigma**2 + self.sigma_data**2) / (sigma * self.sigma_data) ** 2
# 定义 V 加权类,继承自 EDMWeighting
class VWeighting(EDMWeighting):
# 初始化方法
def __init__(self):
# 调用父类初始化方法,设置 sigma_data 为 1.0
super().__init__(sigma_data=1.0)
# 定义 Eps 加权类
class EpsWeighting:
# 定义可调用方法,接收参数 sigma
def __call__(self, sigma):
# 返回 sigma 的 -2 次幂
return sigma**-2.0
.\cogview3-finetune\sat\sgm\modules\diffusionmodules\discretizer.py
# 从 abc 模块导入抽象方法,用于定义抽象基类
from abc import abstractmethod
# 从 functools 模块导入 partial,用于创建偏函数
from functools import partial
# 导入 numpy 库,主要用于数值计算
import numpy as np
# 导入 torch 库,主要用于深度学习和张量操作
import torch
# 从自定义模块中导入 make_beta_schedule 函数,用于生成 beta 时间表
from ...modules.diffusionmodules.util import make_beta_schedule
# 从自定义模块中导入 append_zero 函数,用于处理数组
from ...util import append_zero
# 定义函数 generate_roughly_equally_spaced_steps,用于生成大致均匀间隔的步骤
def generate_roughly_equally_spaced_steps(
num_substeps: int, max_step: int
) -> np.ndarray:
# 生成从 max_step-1 到 0 的均匀间隔数组,包含 num_substeps 个元素,并将结果反转
return np.linspace(max_step - 1, 0, num_substeps, endpoint=False).astype(int)[::-1]
# 定义函数 sub_generate_roughly_equally_spaced_steps,用于生成两个子步骤的均匀间隔
def sub_generate_roughly_equally_spaced_steps(
num_substeps_1: int, num_substeps_2: int, max_step: int
) -> np.ndarray:
# 生成第二组子步骤的均匀间隔
substeps_2 = np.linspace(max_step - 1, 0, num_substeps_2, endpoint=False).astype(int)[::-1]
# 生成第一组子步骤的均匀间隔
substeps_1 = np.linspace(num_substeps_2 - 1, 0, num_substeps_1, endpoint=False).astype(int)[::-1]
# 返回根据第一组子步骤索引获取第二组子步骤的数组
return substeps_2[substeps_1]
# 定义离散化的抽象基类 Discretization
class Discretization:
# 定义可调用方法,接受 n、do_append_zero、device 和 flip 参数
def __call__(self, n, do_append_zero=True, device="cpu", flip=False):
# 调用 get_sigmas 方法获取 sigma 值
sigmas = self.get_sigmas(n, device=device)
# 如果 do_append_zero 为真,向 sigmas 添加零
sigmas = append_zero(sigmas) if do_append_zero else sigmas
# 根据 flip 参数决定是否反转 sigmas
return sigmas if not flip else torch.flip(sigmas, (0,))
# 定义抽象方法 get_sigmas,必须在子类中实现
@abstractmethod
def get_sigmas(self, n, device):
pass
# 定义 EDMDiscretization 类,继承自 Discretization
class EDMDiscretization(Discretization):
# 初始化方法,设置 sigma 的最小值、最大值和 rho 值
def __init__(self, sigma_min=0.002, sigma_max=80.0, rho=7.0):
self.sigma_min = sigma_min
self.sigma_max = sigma_max
self.rho = rho
# 实现抽象方法 get_sigmas,计算 sigma 值
def get_sigmas(self, n, device="cpu"):
# 生成从 0 到 1 的等间隔张量
ramp = torch.linspace(0, 1, n, device=device)
# 计算最小和最大 rho 的倒数
min_inv_rho = self.sigma_min ** (1 / self.rho)
max_inv_rho = self.sigma_max ** (1 / self.rho)
# 根据公式计算 sigmas
sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** self.rho
# 返回计算得到的 sigmas
return sigmas
# 定义 LegacyDDPMDiscretization 类,继承自 Discretization
class LegacyDDPMDiscretization(Discretization):
# 初始化方法,设置线性开始、结束和时间步数
def __init__(
self,
linear_start=0.00085,
linear_end=0.0120,
num_timesteps=1000,
):
# 调用父类初始化方法
super().__init__()
self.num_timesteps = num_timesteps
# 使用 make_beta_schedule 生成 beta 时间表
betas = make_beta_schedule(
"linear", num_timesteps, linear_start=linear_start, linear_end=linear_end
)
# 计算 alphas
alphas = 1.0 - betas
# 计算累积的 alphas
self.alphas_cumprod = np.cumprod(alphas, axis=0)
# 创建将 numpy 数组转换为 torch 张量的偏函数
self.to_torch = partial(torch.tensor, dtype=torch.float32)
# 实现抽象方法 get_sigmas,根据 n 计算 sigma 值
def get_sigmas(self, n, device="cpu"):
# 如果 n 小于总时间步数
if n < self.num_timesteps:
# 生成大致均匀间隔的时间步
timesteps = generate_roughly_equally_spaced_steps(n, self.num_timesteps)
# 获取对应的累积 alphas
alphas_cumprod = self.alphas_cumprod[timesteps]
# 如果 n 等于总时间步数
elif n == self.num_timesteps:
alphas_cumprod = self.alphas_cumprod
# 如果 n 大于总时间步数,抛出异常
else:
raise ValueError
# 创建将 numpy 数组转换为 torch 张量的偏函数,指定设备
to_torch = partial(torch.tensor, dtype=torch.float32, device=device)
# 计算 sigmas 值
sigmas = to_torch((1 - alphas_cumprod) / alphas_cumprod) ** 0.5
# 返回反转的 sigmas
return torch.flip(sigmas, (0,)) # sigma_t: 14.4 -> 0.029
# 定义 ZeroSNRDDPMDiscretization 类,继承自 Discretization
class ZeroSNRDDPMDiscretization(Discretization):
# 初始化方法,设置线性开始、结束、时间步数和 shift_scale
def __init__(
self,
linear_start=0.00085,
linear_end=0.0120,
num_timesteps=1000,
shift_scale=1.,
# 初始化父类
):
super().__init__()
# 设置时间步数
self.num_timesteps = num_timesteps
# 生成线性调度的 beta 值
betas = make_beta_schedule(
"linear", num_timesteps, linear_start=linear_start, linear_end=linear_end
)
# 计算 alpha 值
alphas = 1.0 - betas
# 计算累积的 alpha 值
self.alphas_cumprod = np.cumprod(alphas, axis=0)
# 将数据转换为 torch 张量的函数
self.to_torch = partial(torch.tensor, dtype=torch.float32)
# 对累积的 alpha 值进行缩放
self.alphas_cumprod = self.alphas_cumprod / (shift_scale + (1-shift_scale) * self.alphas_cumprod)
# 获取 sigma 值的方法
def get_sigmas(self, n, device="cpu", return_idx=False):
# 判断请求的时间步数是否小于总时间步数
if n < self.num_timesteps:
# 生成等间距的时间步
timesteps = generate_roughly_equally_spaced_steps(n, self.num_timesteps)
# 获取对应的累积 alpha 值
alphas_cumprod = self.alphas_cumprod[timesteps]
# 判断请求的时间步数是否等于总时间步数
elif n == self.num_timesteps:
alphas_cumprod = self.alphas_cumprod
# 如果超出范围,抛出异常
else:
raise ValueError
# 将 alpha 值转换为 torch 张量
to_torch = partial(torch.tensor, dtype=torch.float32, device=device)
alphas_cumprod = to_torch(alphas_cumprod)
# 计算累积 alpha 值的平方根
alphas_cumprod_sqrt = alphas_cumprod.sqrt()
# 克隆初始和最终的 alpha 值
alphas_cumprod_sqrt_0 = alphas_cumprod_sqrt[0].clone()
alphas_cumprod_sqrt_T = alphas_cumprod_sqrt[-1].clone()
# 对平方根进行归一化处理
alphas_cumprod_sqrt -= alphas_cumprod_sqrt_T
alphas_cumprod_sqrt *= alphas_cumprod_sqrt_0 / (alphas_cumprod_sqrt_0 - alphas_cumprod_sqrt_T)
# 根据返回标志返回结果
if return_idx:
return torch.flip(alphas_cumprod_sqrt, (0,)), timesteps
else:
# 返回反转的平方根 alpha 值
return torch.flip(alphas_cumprod_sqrt, (0,)) # sqrt(alpha_t): 0 -> 0.99
# 使对象可调用的方法
def __call__(self, n, do_append_zero=True, device="cpu", flip=False, return_idx=False):
# 根据返回标志调用获取 sigma 值的方法
if return_idx:
sigmas, idx = self.get_sigmas(n, device=device, return_idx=True)
sigmas = append_zero(sigmas) if do_append_zero else sigmas
# 根据 flip 标志返回结果
return (sigmas, idx) if not flip else (torch.flip(sigmas, (0,)), torch.flip(idx, (0,)))
else:
# 获取 sigma 值并处理
sigmas = self.get_sigmas(n, device=device)
sigmas = append_zero(sigmas) if do_append_zero else sigmas
# 根据 flip 标志返回结果
return sigmas if not flip else torch.flip(sigmas, (0,))
.\cogview3-finetune\sat\sgm\modules\diffusionmodules\dit.py
# 从 omegaconf 导入 DictConfig 类,用于配置管理
from omegaconf import DictConfig
# 从 functools 导入 partial 函数,用于偏函数应用
from functools import partial
# 从 einops 导入 rearrange 函数,用于重排张量
from einops import rearrange
# 导入 numpy 库,用于数值计算
import numpy as np
# 导入 PyTorch 库
import torch
# 从 torch 导入 nn 模块,包含神经网络构建相关的工具
from torch import nn
# 导入 PyTorch 的分布式训练模块
import torch.distributed
# 从 sat.model.base_model 导入 BaseModel 类,作为模型基类
from sat.model.base_model import BaseModel
# 从 sat.model.mixins 导入 BaseMixin 类,用于混入模型功能
from sat.model.mixins import BaseMixin
# 从 sat.ops.layernorm 导入 LayerNorm 类,用于层归一化
from sat.ops.layernorm import LayerNorm
# 从 sat.transformer_defaults 导入默认的 hooks 和注意力函数
from sat.transformer_defaults import HOOKS_DEFAULT, attention_fn_default
# 从 sat.mpu.utils 导入用于张量分割的工具
from sat.mpu.utils import split_tensor_along_last_dim
# 从 sgm.util 导入一些工具函数
from sgm.util import (
disabled_train, # 禁用训练的装饰器
instantiate_from_config, # 从配置实例化对象
)
# 从 sgm.modules.diffusionmodules.openaimodel 导入时间步类
from sgm.modules.diffusionmodules.openaimodel import Timestep
# 从 sgm.modules.diffusionmodules.util 导入卷积、线性层和时间步嵌入等工具
from sgm.modules.diffusionmodules.util import (
conv_nd, # 多维卷积函数
linear, # 线性变换函数
timestep_embedding, # 时间步嵌入函数
)
# 定义调制函数,接受输入张量、偏移量和缩放因子
def modulate(x, shift, scale):
# 根据缩放因子和偏移量对输入张量进行调制
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
# 定义解patch化函数,将补丁张量恢复为图像张量
def unpatchify(x, channels, patch_size, height, width):
# 使用 rearrange 函数将补丁张量重排为图像张量
x = rearrange(
x,
"b (h w) (c p1 p2) -> b c (h p1) (w p2)", # 指定输入和输出的张量维度
h=height // patch_size, # 计算行数
w=width // patch_size, # 计算列数
p1=patch_size, # 每个补丁的高度
p2=patch_size, # 每个补丁的宽度
)
# 返回解patch化后的图像张量
return x
# 定义图像补丁嵌入混入类,继承自 BaseMixin
class ImagePatchEmbeddingMixin(BaseMixin):
# 初始化函数,设置输入通道、隐藏层大小、补丁大小等属性
def __init__(
self,
in_channels,
hidden_size,
patch_size,
text_hidden_size=None,
do_rearrange=True,
):
# 调用父类初始化
super().__init__()
# 设置输入通道数
self.in_channels = in_channels
# 设置隐藏层大小
self.hidden_size = hidden_size
# 设置补丁大小
self.patch_size = patch_size
# 设置文本隐藏层大小(如果有的话)
self.text_hidden_size = text_hidden_size
# 设置是否重排张量的标志
self.do_rearrange = do_rearrange
# 初始化线性层,将补丁的通道数映射到隐藏层大小
self.proj = nn.Linear(in_channels * patch_size ** 2, hidden_size)
# 如果提供了文本隐藏层大小,则初始化相应的线性层
if text_hidden_size is not None:
self.text_proj = nn.Linear(text_hidden_size, hidden_size)
# 定义词嵌入前向传播函数,接受输入ID、图像和编码器输出
def word_embedding_forward(self, input_ids, images, encoder_outputs, **kwargs):
# images: B x C x H x W,表示批量图像的形状
# 如果需要重排图像张量
if self.do_rearrange:
# 使用 rearrange 函数将图像重排为补丁格式
patches_images = rearrange(
images, "b c (h p1) (w p2) -> b (h w) (c p1 p2)", # 指定输入和输出的张量维度
p1=self.patch_size, # 每个补丁的高度
p2=self.patch_size, # 每个补丁的宽度
)
else:
# 否则直接使用原始图像张量
patches_images = images
# 通过线性层对补丁图像进行映射
emb = self.proj(patches_images)
# 如果有文本隐藏层大小
if self.text_hidden_size is not None:
# 对编码器输出进行线性映射
text_emb = self.text_proj(encoder_outputs)
# 将文本嵌入与图像嵌入在维度1上进行连接
emb = torch.cat([text_emb, emb], dim=1)
# 返回最终的嵌入结果
return emb
# 定义重新初始化函数
def reinit(self, parent_model=None):
# 获取线性层的权重
w = self.proj.weight.data
# 使用 Xavier 均匀分布初始化权重
nn.init.xavier_uniform_(self.proj.weight)
# 将偏置初始化为零
nn.init.constant_(self.proj.bias, 0)
# 删除 transformer 的词嵌入
del self.transformer.word_embeddings
# 定义获取 2D 正弦余弦位置嵌入的函数
def get_2d_sincos_pos_embed(embed_dim, grid_height, grid_width, cls_token=False, extra_tokens=0):
"""
grid_size: int of the grid height and width
return:
pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
"""
# 创建一个表示网格高度的数组
grid_h = np.arange(grid_height, dtype=np.float32)
# 创建一个表示网格宽度的数组
grid_w = np.arange(grid_width, dtype=np.float32)
# 生成网格坐标,宽度优先
grid = np.meshgrid(grid_w, grid_h) # here w goes first
# 将网格数组堆叠到一个新的维度
grid = np.stack(grid, axis=0)
# 将网格重塑为 [2, 1, grid_height, grid_width] 形状
grid = grid.reshape([2, 1, grid_height, grid_width])
# 从网格生成二维正弦余弦位置嵌入
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
# 如果需要分类标记且额外的标记数大于0,则进行处理
if cls_token and extra_tokens > 0:
# 在位置嵌入前添加额外的零嵌入,形成新的位置嵌入
pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
# 返回最终的位置嵌入
return pos_embed
# 从网格生成二维正弦余弦位置嵌入
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
# 确保嵌入维度是偶数
assert embed_dim % 2 == 0
# 使用一半的维度编码网格高度
emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
# 使用一半的维度编码网格宽度
emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
# 将高度和宽度的嵌入合并
emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
# 返回最终的嵌入
return emb
# 从网格生成一维正弦余弦位置嵌入
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
"""
embed_dim: 每个位置的输出维度
pos: 需要编码的位置列表:大小 (M,)
out: (M, D)
"""
# 确保嵌入维度是偶数
assert embed_dim % 2 == 0
# 生成 omega 数组用于位置编码
omega = np.arange(embed_dim // 2, dtype=np.float64)
omega /= embed_dim / 2.0
# 计算频率,得到 (D/2,)
omega = 1.0 / 10000 ** omega # (D/2,)
# 将位置调整为一维
pos = pos.reshape(-1) # (M,)
# 计算位置与频率的外积
out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product
# 计算正弦和余弦嵌入
emb_sin = np.sin(out) # (M, D/2)
emb_cos = np.cos(out) # (M, D/2)
# 将正弦和余弦嵌入合并
emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
# 返回最终的嵌入
return emb
# 位置嵌入混合类
class PositionEmbeddingMixin(BaseMixin):
def __init__(
self,
max_height,
max_width,
hidden_size,
text_length=0,
block_size=16,
**kwargs,
):
# 初始化父类
super().__init__()
# 设置最大高度
self.max_height = max_height
# 设置最大宽度
self.max_width = max_width
# 设置隐藏层大小
self.hidden_size = hidden_size
# 设置文本长度
self.text_length = text_length
# 设置块大小
self.block_size = block_size
# 初始化图像位置嵌入参数
self.image_pos_embedding = nn.Parameter(
torch.zeros(self.max_height, self.max_width, hidden_size), requires_grad=False
)
# 前向传播计算位置嵌入
def position_embedding_forward(self, position_ids, target_size, **kwargs):
ret = []
# 遍历目标大小
for h, w in target_size:
# 将高度和宽度除以块大小
h, w = h // self.block_size, w // self.block_size
# 获取图像位置嵌入并重塑
image_pos_embed = self.image_pos_embedding[:h, :w].reshape(h * w, -1)
# 连接文本嵌入与图像嵌入
pos_embed = torch.cat(
[
torch.zeros(
(self.text_length, self.hidden_size),
dtype=image_pos_embed.dtype,
device=image_pos_embed.device,
),
image_pos_embed,
],
dim=0,
)
# 添加到结果列表中
ret.append(pos_embed[None, ...])
# 合并所有位置嵌入
return torch.cat(ret, dim=0)
# 重新初始化位置嵌入
def reinit(self, parent_model=None):
# 删除当前位置嵌入
del self.transformer.position_embeddings
# 获取新的二维正弦余弦位置嵌入
pos_embed = get_2d_sincos_pos_embed(self.image_pos_embedding.shape[-1], self.max_height, self.max_width)
# 重塑位置嵌入为二维形状
pos_embed = pos_embed.reshape(self.max_height, self.max_width, -1)
# 复制新的位置嵌入数据
self.image_pos_embedding.data.copy_(torch.from_numpy(pos_embed).float())
# 最终层混合类
class FinalLayerMixin(BaseMixin):
def __init__(
self,
hidden_size,
time_embed_dim,
patch_size,
block_size,
out_channels,
elementwise_affine=False,
eps=1e-6,
do_unpatchify=True,
):
# 调用父类构造函数进行初始化
super().__init__()
# 设置隐藏层大小
self.hidden_size = hidden_size
# 设置每个补丁的大小
self.patch_size = patch_size
# 设置块的大小
self.block_size = block_size
# 设置输出通道数
self.out_channels = out_channels
# 确定是否进行去补丁处理
self.do_unpatchify = do_unpatchify
# 初始化最终的层归一化,带有可学习的参数
self.norm_final = nn.LayerNorm(
hidden_size,
elementwise_affine=elementwise_affine,
eps=eps,
)
# 创建一个包含SiLU激活和线性层的序列
self.adaln = nn.Sequential(
nn.SiLU(),
nn.Linear(time_embed_dim, 2 * hidden_size),
)
# 初始化线性层以将隐藏状态映射到输出通道
self.linear = nn.Linear(hidden_size, out_channels * patch_size ** 2)
def final_forward(self, logits, emb, text_length, target_size=None, **kwargs):
# 截取logits以获取文本长度后的部分
x = logits[:, text_length:]
# 使用adaln模块对嵌入进行变换并拆分为偏移和缩放
shift, scale = self.adaln(emb).chunk(2, dim=1)
# 对x进行归一化并应用偏移和缩放
x = modulate(self.norm_final(x), shift, scale)
# 通过线性层获得最终输出
x = self.linear(x)
# 如果需要去补丁处理
if self.do_unpatchify:
# 从目标大小中提取高度和宽度
target_height, target_width = target_size[0]
# 断言目标大小必须能被块大小整除
assert (
target_height % self.block_size == 0 and target_width % self.block_size == 0
), "target size must be divisible by block size"
# 计算输出高度和宽度
out_height, out_width = (
target_height // self.block_size * self.patch_size,
target_width // self.block_size * self.patch_size,
)
# 进行去补丁处理,恢复原图
x = unpatchify(
x, channels=self.out_channels, patch_size=self.patch_size, height=out_height, width=out_width
)
# 返回最终输出
return x
def reinit(self, parent_model=None):
# 使用Xavier均匀分布初始化线性层权重
nn.init.xavier_uniform_(self.linear.weight)
# 将线性层偏置初始化为0
nn.init.constant_(self.linear.bias, 0)
# 定义一个混合类 AdalnAttentionMixin,继承自 BaseMixin
class AdalnAttentionMixin(BaseMixin):
# 初始化函数,接受多个参数以设置模型的各项属性
def __init__(
self,
hidden_size, # 隐藏层大小
num_layers, # 层数
time_embed_dim, # 时间嵌入维度
qk_ln=True, # 是否使用查询和键的层归一化
hidden_size_head=None, # 头部的隐藏层大小
elementwise_affine=False, # 是否使用逐元素仿射变换
eps=1e-6, # 用于层归一化的平滑项
):
# 调用父类的初始化方法
super().__init__()
# 创建一个包含多个顺序模块的模块列表,每个模块由 SiLU 激活函数和线性层组成
self.adaln_modules = nn.ModuleList(
[nn.Sequential(nn.SiLU(), nn.Linear(time_embed_dim, 12 * hidden_size)) for _ in range(num_layers)]
)
# 记录是否使用查询和键的层归一化
self.qk_ln = qk_ln
# 如果使用层归一化,则为查询和键分别创建模块列表
if qk_ln:
# 创建用于查询的层归一化模块列表
self.query_layernorms = nn.ModuleList(
[
LayerNorm(hidden_size_head, elementwise_affine=elementwise_affine, eps=eps)
for _ in range(num_layers) # 为每一层创建一个层归一化模块
]
)
# 创建用于键的层归一化模块列表
self.key_layernorms = nn.ModuleList(
[
LayerNorm(hidden_size_head, elementwise_affine=elementwise_affine, eps=eps)
for _ in range(num_layers) # 为每一层创建一个层归一化模块
]
)
# 定义前向传播的方法,接收隐藏状态、掩码、文本长度等参数
def layer_forward(
self,
hidden_states, # 当前层的隐藏状态
mask, # 掩码,用于忽略特定位置
text_length, # 文本的长度
layer_id, # 当前层的索引
emb, # 嵌入表示
*args, # 可变参数
**kwargs, # 关键字参数
# 定义一个方法的结尾,接受必要的参数
):
# 获取指定层的 Transformer 层
layer = self.transformer.layers[layer_id]
# 获取与该层对应的自适应层归一化模块
adaln_module = self.adaln_modules[layer_id]
# 从自适应层归一化模块中处理输入并将结果分块为 12 个部分
(
shift_msa_img,
scale_msa_img,
gate_msa_img,
shift_mlp_img,
scale_mlp_img,
gate_mlp_img,
shift_msa_txt,
scale_msa_txt,
gate_msa_txt,
shift_mlp_txt,
scale_mlp_txt,
gate_mlp_txt,
) = adaln_module(emb).chunk(12, dim=1)
# 扩展门控张量的维度,以便后续处理
gate_msa_img, gate_mlp_img, gate_msa_txt, gate_mlp_txt = (
gate_msa_img.unsqueeze(1),
gate_mlp_img.unsqueeze(1),
gate_msa_txt.unsqueeze(1),
gate_mlp_txt.unsqueeze(1),
)
# 对输入的隐藏状态进行层归一化处理
attention_input = layer.input_layernorm(hidden_states)
# 对文本输入进行调制,应用相应的偏移和缩放
text_attention_input = modulate(attention_input[:, :text_length], shift_msa_txt, scale_msa_txt)
# 对图像输入进行调制,应用相应的偏移和缩放
image_attention_input = modulate(attention_input[:, text_length:], shift_msa_img, scale_msa_img)
# 将文本和图像的注意力输入合并
attention_input = torch.cat((text_attention_input, image_attention_input), dim=1)
# 计算注意力输出,应用遮罩和层 ID
attention_output = layer.attention(attention_input, mask, layer_id=layer_id, **kwargs)
# 如果层归一化顺序是 "sandwich",则应用第三次层归一化
if self.transformer.layernorm_order == "sandwich":
attention_output = layer.third_layernorm(attention_output)
# 将隐藏状态分为文本和图像部分
text_hidden_states, image_hidden_states = hidden_states[:, :text_length], hidden_states[:, text_length:]
# 将注意力输出分为文本和图像部分
text_attention_output, image_attention_output = (
attention_output[:, :text_length],
attention_output[:, text_length:],
)
# 更新文本隐藏状态,加上文本注意力输出的加权
text_hidden_states = text_hidden_states + gate_msa_txt * text_attention_output
# 更新图像隐藏状态,加上图像注意力输出的加权
image_hidden_states = image_hidden_states + gate_msa_img * image_attention_output
# 合并更新后的隐藏状态
hidden_states = torch.cat((text_hidden_states, image_hidden_states), dim=1)
# 对合并后的隐藏状态进行后注意力层归一化
mlp_input = layer.post_attention_layernorm(hidden_states)
# 对文本输入进行调制,应用相应的偏移和缩放
text_mlp_input = modulate(mlp_input[:, :text_length], shift_mlp_txt, scale_mlp_txt)
# 对图像输入进行调制,应用相应的偏移和缩放
image_mlp_input = modulate(mlp_input[:, text_length:], shift_mlp_img, scale_mlp_img)
# 将文本和图像的 MLP 输入合并
mlp_input = torch.cat((text_mlp_input, image_mlp_input), dim=1)
# 计算 MLP 输出,应用层 ID
mlp_output = layer.mlp(mlp_input, layer_id=layer_id, **kwargs)
# 如果层归一化顺序是 "sandwich",则应用第四次层归一化
if self.transformer.layernorm_order == "sandwich":
mlp_output = layer.fourth_layernorm(mlp_output)
# 将隐藏状态分为文本和图像部分
text_hidden_states, image_hidden_states = hidden_states[:, :text_length], hidden_states[:, text_length:]
# 将 MLP 输出分为文本和图像部分
text_mlp_output, image_mlp_output = mlp_output[:, :text_length], mlp_output[:, text_length:]
# 更新文本隐藏状态,加上文本 MLP 输出的加权
text_hidden_states = text_hidden_states + gate_mlp_txt * text_mlp_output
# 更新图像隐藏状态,加上图像 MLP 输出的加权
image_hidden_states = image_hidden_states + gate_mlp_img * image_mlp_output
# 合并更新后的隐藏状态
hidden_states = torch.cat((text_hidden_states, image_hidden_states), dim=1)
# 返回最终的隐藏状态
return hidden_states
# 定义注意力前向传播函数,接收隐藏状态、掩码、层ID及其他参数
def attention_forward(self, hidden_states, mask, layer_id, **kwargs):
# 获取指定层的注意力模块
attention = self.transformer.layers[layer_id].attention
# 默认的注意力计算函数
attention_fn = attention_fn_default
# 如果注意力模块有自定义的注意力函数,则使用它
if "attention_fn" in attention.hooks:
attention_fn = attention.hooks["attention_fn"]
# 通过隐藏状态计算查询、键、值
qkv = attention.query_key_value(hidden_states)
# 将查询、键、值沿最后一个维度分离
mixed_query_layer, mixed_key_layer, mixed_value_layer = split_tensor_along_last_dim(qkv, 3)
# 根据训练状态选择是否应用 dropout
dropout_fn = attention.attention_dropout if self.training else None
# 转置查询、键、值以便于后续计算
query_layer = attention._transpose_for_scores(mixed_query_layer)
key_layer = attention._transpose_for_scores(mixed_key_layer)
value_layer = attention._transpose_for_scores(mixed_value_layer)
# 如果使用层归一化,应用于查询和键
if self.qk_ln:
query_layernorm = self.query_layernorms[layer_id]
key_layernorm = self.key_layernorms[layer_id]
query_layer = query_layernorm(query_layer)
key_layer = key_layernorm(key_layer)
# 计算上下文层,使用指定的注意力函数
context_layer = attention_fn(query_layer, key_layer, value_layer, mask, dropout_fn, **kwargs)
# 调整上下文层的维度顺序
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
# 创建新的上下文层形状,适配后续计算
new_context_layer_shape = context_layer.size()[:-2] + (attention.hidden_size_per_partition,)
# 重新调整上下文层的形状
context_layer = context_layer.view(*new_context_layer_shape)
# 通过全连接层计算输出
output = attention.dense(context_layer)
# 如果处于训练状态,应用输出的 dropout
if self.training:
output = attention.output_dropout(output)
# 返回最终输出
return output
# 定义多层感知机的前向传播函数
def mlp_forward(self, hidden_states, layer_id, **kwargs):
# 获取指定层的多层感知机模块
mlp = self.transformer.layers[layer_id].mlp
# 通过全连接层将隐藏状态映射到更高维度
intermediate_parallel = mlp.dense_h_to_4h(hidden_states)
# 应用激活函数
intermediate_parallel = mlp.activation_func(intermediate_parallel)
# 将高维结果映射回原维度
output = mlp.dense_4h_to_h(intermediate_parallel)
# 如果处于训练状态,应用 dropout
if self.training:
output = mlp.dropout(output)
# 返回最终输出
return output
# 定义重新初始化函数,接受可选的父模型参数
def reinit(self, parent_model=None):
# 遍历自适应层模块
for layer in self.adaln_modules:
# 将最后一层的权重初始化为 0
nn.init.constant_(layer[-1].weight, 0)
# 将最后一层的偏置初始化为 0
nn.init.constant_(layer[-1].bias, 0)
# 定义一个字符串到数据类型的映射
str_to_dtype = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}
# 定义一个扩散变换器类,继承自基础模型
class DiffusionTransformer(BaseModel):
# 初始化方法,接受多个参数以配置模型
def __init__(
self,
in_channels, # 输入通道数
out_channels, # 输出通道数
hidden_size, # 隐藏层大小
patch_size, # 图像块大小
num_layers, # 层数
num_attention_heads, # 注意力头数
text_length, # 文本长度
time_embed_dim=None, # 时间嵌入维度,默认为 None
num_classes=None, # 类别数量,默认为 None
adm_in_channels=None, # 自适应输入通道,默认为 None
modules={}, # 额外模块,默认为空字典
dtype="fp32", # 数据类型,默认为 fp32
layernorm_order="pre", # 层归一化顺序,默认为 "pre"
elementwise_affine=False, # 是否启用逐元素仿射,默认为 False
parallel_output=True, # 是否并行输出,默认为 True
block_size=16, # 块大小,默认为 16
**kwargs, # 其他关键字参数
):
# 初始化基类
super().__init__(**kwargs)
# 前向传播方法,定义模型的前向计算
def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
# 将输入张量 x 转换为指定的数据类型
x = x.to(self.dtype)
# 获取时间步的嵌入表示
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False, dtype=self.dtype)
# 对时间嵌入进行处理
emb = self.time_embed(t_emb)
# 确保 y 和 x 的批次大小一致
assert y.shape[0] == x.shape[0]
# 将标签嵌入与时间嵌入相加
emb = emb + self.label_emb(y)
# 创建输入 ID、位置 ID 和注意力掩码,均初始化为 1
input_ids = position_ids = attention_mask = torch.ones((1, 1)).to(x.dtype)
# 调用基类的前向方法,传入多个参数以计算输出
output = super().forward(
images=x, # 输入图像
emb=emb, # 嵌入表示
encoder_outputs=context, # 编码器输出
text_length=self.text_length, # 文本长度
input_ids=input_ids, # 输入 ID
position_ids=position_ids, # 位置 ID
attention_mask=attention_mask, # 注意力掩码
**kwargs, # 其他关键字参数
)[0] # 获取输出的第一个元素
# 返回模型的输出
return output