CogView3---CogView-3Plus-微调代码源码解析-二-

CogView3 & CogView-3Plus 微调代码源码解析(二)

.\cogview3-finetune\sat\sgm\models\__init__.py

# 从同一模块导入 AutoencodingEngine 类,用于后续的自动编码器操作
from .autoencoder import AutoencodingEngine

# 注释文本(可能是无关信息或标识符)
#XuDwndGaCFo

.\cogview3-finetune\sat\sgm\modules\attention.py

# 导入数学库
import math
# 从 inspect 模块导入 isfunction 函数,用于检查对象是否为函数
from inspect import isfunction
# 导入 Any 和 Optional 类型
from typing import Any, Optional

# 导入 PyTorch 库
import torch
# 导入 PyTorch 的功能性模块
import torch.nn.functional as F
# 从 einops 库导入 rearrange 和 repeat 函数
from einops import rearrange, repeat
# 导入版本管理工具
from packaging import version
# 导入 PyTorch 的神经网络模块
from torch import nn

# 检查 PyTorch 版本是否大于或等于 2.0.0
if version.parse(torch.__version__) >= version.parse("2.0.0"):
    # 设置 SDP_IS_AVAILABLE 为 True,表示 SDP 后端可用
    SDP_IS_AVAILABLE = True
    # 从 PyTorch 导入 SDPBackend 和 sdp_kernel
    from torch.backends.cuda import SDPBackend, sdp_kernel

    # 定义后端映射字典,根据不同的后端配置相应的选项
    BACKEND_MAP = {
        SDPBackend.MATH: {
            "enable_math": True,
            "enable_flash": False,
            "enable_mem_efficient": False,
        },
        SDPBackend.FLASH_ATTENTION: {
            "enable_math": False,
            "enable_flash": True,
            "enable_mem_efficient": False,
        },
        SDPBackend.EFFICIENT_ATTENTION: {
            "enable_math": False,
            "enable_flash": False,
            "enable_mem_efficient": True,
        },
        None: {"enable_math": True, "enable_flash": True, "enable_mem_efficient": True},
    }
else:
    # 从上下文管理库导入 nullcontext
    from contextlib import nullcontext

    # 设置 SDP_IS_AVAILABLE 为 False,表示 SDP 后端不可用
    SDP_IS_AVAILABLE = False
    # 将 sdp_kernel 设置为 nullcontext
    sdp_kernel = nullcontext
    # 打印提示信息,告知用户当前 PyTorch 版本不支持 SDP 后端
    print(
        f"No SDP backend available, likely because you are running in pytorch versions < 2.0. In fact, "
        f"you are using PyTorch {torch.__version__}. You might want to consider upgrading."
    )

# 尝试导入 xformers 和 xformers.ops
try:
    import xformers
    import xformers.ops

    # 如果导入成功,设置 XFORMERS_IS_AVAILABLE 为 True
    XFORMERS_IS_AVAILABLE = True
# 如果导入失败,设置 XFORMERS_IS_AVAILABLE 为 False,并打印提示信息
except:
    XFORMERS_IS_AVAILABLE = False
    print("no module 'xformers'. Processing without...")

# 从 diffusionmodules.util 模块导入 checkpoint 函数
from .diffusionmodules.util import checkpoint


# 定义 exists 函数,检查输入值是否存在
def exists(val):
    return val is not None


# 定义 uniq 函数,返回数组中唯一元素的键
def uniq(arr):
    return {el: True for el in arr}.keys()


# 定义 default 函数,如果 val 存在则返回它,否则返回默认值
def default(val, d):
    if exists(val):
        return val
    return d() if isfunction(d) else d


# 定义 max_neg_value 函数,返回给定张量类型的最大负值
def max_neg_value(t):
    return -torch.finfo(t.dtype).max


# 定义 init_ 函数,初始化张量
def init_(tensor):
    # 获取张量的最后一维的大小
    dim = tensor.shape[-1]
    # 计算标准差
    std = 1 / math.sqrt(dim)
    # 在区间 [-std, std] 内均匀初始化张量
    tensor.uniform_(-std, std)
    return tensor


# 定义 GEGLU 类,继承自 nn.Module
class GEGLU(nn.Module):
    # 初始化方法,设置输入和输出维度
    def __init__(self, dim_in, dim_out):
        super().__init__()
        # 创建一个线性投影层,将输入维度映射到两倍的输出维度
        self.proj = nn.Linear(dim_in, dim_out * 2)

    # 前向传播方法
    def forward(self, x):
        # 将输入通过投影层,分割为 x 和 gate
        x, gate = self.proj(x).chunk(2, dim=-1)
        # 返回 x 与 gate 的 GELU 激活后的乘积
        return x * F.gelu(gate)


# 定义 FeedForward 类,继承自 nn.Module
class FeedForward(nn.Module):
    # 初始化方法,设置维度、输出维度、乘法因子、是否使用 GEGLU 和 dropout 率
    def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
        super().__init__()
        # 计算内部维度
        inner_dim = int(dim * mult)
        # 如果 dim_out 未定义,使用 dim 作为默认值
        dim_out = default(dim_out, dim)
        # 根据是否使用 GEGLU 创建输入投影层
        project_in = (
            nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU())
            if not glu
            else GEGLU(dim, inner_dim)
        )

        # 定义网络结构
        self.net = nn.Sequential(
            project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)
        )

    # 前向传播方法
    def forward(self, x):
        # 通过网络结构处理输入
        return self.net(x)


# 定义 zero_module 函数,清零模块的参数并返回该模块
def zero_module(module):
    """
    Zero out the parameters of a module and return it.
    """
    # 遍历模块的所有参数
    for p in module.parameters():
        # 将参数的梯度断开并清零
        p.detach().zero_()
    return module


# 定义 Normalize 函数,接收输入通道数
def Normalize(in_channels):
    # 返回一个 GroupNorm 实例,用于对输入进行分组归一化
        return torch.nn.GroupNorm(
            # 设置分组数量为 32
            num_groups=32, 
            # 设置输入通道数量
            num_channels=in_channels, 
            # 设置一个小的 epsilon 值以避免除零错误
            eps=1e-6, 
            # 设定 affine 为 True,以便进行可学习的仿射变换
            affine=True
        )
# 定义线性注意力机制的类,继承自 nn.Module
class LinearAttention(nn.Module):
    # 初始化方法,设置参数维度和头数
    def __init__(self, dim, heads=4, dim_head=32):
        # 调用父类构造函数
        super().__init__()
        # 设置头数
        self.heads = heads
        # 计算隐藏维度
        hidden_dim = dim_head * heads
        # 定义输入到 QKV 的卷积层,输出通道为 hidden_dim 的三倍
        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
        # 定义输出卷积层,输出通道为原始维度
        self.to_out = nn.Conv2d(hidden_dim, dim, 1)

    # 前向传播方法
    def forward(self, x):
        # 获取输入的批次大小、通道数、高度和宽度
        b, c, h, w = x.shape
        # 将输入通过 QKV 卷积层
        qkv = self.to_qkv(x)
        # 将 QKV 的输出重排以分开 Q、K、V
        q, k, v = rearrange(
            qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3
        )
        # 对 K 进行 softmax 归一化
        k = k.softmax(dim=-1)
        # 计算上下文向量,通过爱因斯坦求和约定
        context = torch.einsum("bhdn,bhen->bhde", k, v)
        # 使用上下文向量和 Q 计算输出
        out = torch.einsum("bhde,bhdn->bhen", context, q)
        # 重排输出以恢复到原始形状
        out = rearrange(
            out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w
        )
        # 返回最终输出
        return self.to_out(out)


# 定义空间自注意力机制的类,继承自 nn.Module
class SpatialSelfAttention(nn.Module):
    # 初始化方法,设置输入通道
    def __init__(self, in_channels):
        # 调用父类构造函数
        super().__init__()
        # 存储输入通道数
        self.in_channels = in_channels

        # 创建归一化层
        self.norm = Normalize(in_channels)
        # 创建 Q、K、V 的卷积层,输出通道与输入通道相同
        self.q = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )
        self.k = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )
        self.v = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )
        # 创建输出的卷积层
        self.proj_out = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )

    # 前向传播方法
    def forward(self, x):
        # 初始化 h_ 为输入 x
        h_ = x
        # 对 h_ 进行归一化处理
        h_ = self.norm(h_)
        # 计算 Q、K、V
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)

        # 计算注意力权重
        b, c, h, w = q.shape
        # 将 Q 重排为 (batch_size, h*w, c) 的形状
        q = rearrange(q, "b c h w -> b (h w) c")
        # 将 K 重排为 (batch_size, c, h*w) 的形状
        k = rearrange(k, "b c h w -> b c (h w)")
        # 计算 Q 和 K 的点积以获得权重
        w_ = torch.einsum("bij,bjk->bik", q, k)

        # 对权重进行缩放
        w_ = w_ * (int(c) ** (-0.5))
        # 对权重进行 softmax 归一化
        w_ = torch.nn.functional.softmax(w_, dim=2)

        # 处理 V
        v = rearrange(v, "b c h w -> b c (h w)")
        # 重排权重 w_ 为 (batch_size, h*w, h*w) 的形状
        w_ = rearrange(w_, "b i j -> b j i")
        # 计算 h_,通过 V 和权重相乘
        h_ = torch.einsum("bij,bjk->bik", v, w_)
        # 将 h_ 重排回原始形状
        h_ = rearrange(h_, "b c (h w) -> b c h w", h=h)
        # 通过 proj_out 进行最终的线性变换
        h_ = self.proj_out(h_)

        # 返回输入与处理后的 h_ 的和
        return x + h_


# 定义交叉注意力机制的类,继承自 nn.Module
class CrossAttention(nn.Module):
    # 初始化方法,设置查询、上下文维度和其他参数
    def __init__(
        self,
        query_dim,
        context_dim=None,
        heads=8,
        dim_head=64,
        dropout=0.0,
        backend=None,
    ):
        # 调用父类构造函数
        super().__init__()
        # 计算内部维度
        inner_dim = dim_head * heads
        # 如果没有提供上下文维度,默认使用查询维度
        context_dim = default(context_dim, query_dim)

        # 设置缩放因子
        self.scale = dim_head**-0.5
        # 设置头数
        self.heads = heads

        # 定义 Q、K、V 的线性变换
        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)

        # 定义输出层,包括线性变换和 dropout
        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)
        )
        # 存储后端
        self.backend = backend
    # 定义前向传播函数,接收输入数据及相关参数
        def forward(
            self,
            x,
            context=None,
            mask=None,
            additional_tokens=None,
            n_times_crossframe_attn_in_self=0,
        ):
            # 获取注意力头的数量
            h = self.heads
    
            # 如果有额外的 tokens
            if additional_tokens is not None:
                # 获取输出序列开始时的掩码 token 数量
                n_tokens_to_mask = additional_tokens.shape[1]
                # 将额外的 token 添加到输入数据前
                x = torch.cat([additional_tokens, x], dim=1)
    
            # 通过线性变换生成查询向量
            q = self.to_q(x)
            # 使用默认值或输入作为上下文
            context = default(context, x)
            # 通过线性变换生成键向量
            k = self.to_k(context)
            # 通过线性变换生成值向量
            v = self.to_v(context)
    
            # 如果需要进行跨帧注意力
            if n_times_crossframe_attn_in_self:
                # 验证输入批次大小可以被跨帧次数整除
                assert x.shape[0] % n_times_crossframe_attn_in_self == 0
                # 计算每次跨帧的批次大小
                n_cp = x.shape[0] // n_times_crossframe_attn_in_self
                # 重复键向量以适应跨帧注意力
                k = repeat(
                    k[::n_times_crossframe_attn_in_self], "b ... -> (b n) ...", n=n_cp
                )
                # 重复值向量以适应跨帧注意力
                v = repeat(
                    v[::n_times_crossframe_attn_in_self], "b ... -> (b n) ...", n=n_cp
                )
    
            # 将查询、键、值向量重排为适合多头注意力的形状
            q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v))
    
            ## old
            """
            # 计算查询与键之间的相似度,并进行缩放
            sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
            # 删除查询和键以节省内存
            del q, k
    
            # 如果存在掩码
            if exists(mask):
                # 将掩码重排为适合的形状
                mask = rearrange(mask, 'b ... -> b (...)')
                # 获取相似度的最大负值
                max_neg_value = -torch.finfo(sim.dtype).max
                # 重复掩码以适应多头
                mask = repeat(mask, 'b j -> (b h) () j', h=h)
                # 使用掩码填充相似度矩阵
                sim.masked_fill_(~mask, max_neg_value)
    
            # 应用 softmax 计算注意力权重
            sim = sim.softmax(dim=-1)
    
            # 使用注意力权重加权值向量,生成输出
            out = einsum('b i j, b j d -> b i d', sim, v)
            """
            ## new
            # 使用指定的后端进行缩放点积注意力计算
            with sdp_kernel(**BACKEND_MAP[self.backend]):
                # print("dispatching into backend", self.backend, "q/k/v shape: ", q.shape, k.shape, v.shape)
                out = F.scaled_dot_product_attention(
                    q, k, v, attn_mask=mask
                )  # 默认缩放因子为 dim_head ** -0.5
    
            # 删除查询、键和值向量以释放内存
            del q, k, v
            # 将输出重排为适合最终输出的形状
            out = rearrange(out, "b h n d -> b n (h d)", h=h)
    
            # 如果有额外的 tokens
            if additional_tokens is not None:
                # 移除额外的 token
                out = out[:, n_tokens_to_mask:]
            # 返回最终的输出结果
            return self.to_out(out)
# 定义一个内存高效的交叉注意力模块,继承自 nn.Module
class MemoryEfficientCrossAttention(nn.Module):
    # 初始化方法,接受查询维度、上下文维度、头数、每个头的维度和丢弃率等参数
    def __init__(
        self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0, **kwargs
    ):
        # 调用父类的初始化方法
        super().__init__()
        # 计算每个头的内部维度,等于头数乘以每个头的维度
        inner_dim = dim_head * heads
        # 如果上下文维度未提供,则将其设置为查询维度
        context_dim = default(context_dim, query_dim)

        # 保存头数和每个头的维度到实例变量
        self.heads = heads
        self.dim_head = dim_head

        # 创建线性层,将查询输入转换为内部维度,不使用偏置
        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
        # 创建线性层,将上下文输入转换为内部维度,不使用偏置
        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
        # 创建线性层,将上下文输入转换为内部维度,不使用偏置
        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)

        # 创建一个顺序容器,包含将内部维度转换回查询维度的线性层和丢弃层
        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)
        )
        # 初始化注意力操作为 None
        self.attention_op: Optional[Any] = None

    # 定义前向传播方法,接受输入、上下文、掩码和其他可选参数
    def forward(
        self,
        x,
        context=None,
        mask=None,
        additional_tokens=None,
        n_times_crossframe_attn_in_self=0,
    ):
        # 检查是否提供了额外的令牌
        if additional_tokens is not None:
            # 获取输出序列开头的被遮掩令牌数量
            n_tokens_to_mask = additional_tokens.shape[1]
            # 将额外的令牌与当前输入合并
            x = torch.cat([additional_tokens, x], dim=1)
        # 将输入转换为查询向量
        q = self.to_q(x)
        # 使用默认值或输入作为上下文
        context = default(context, x)
        # 将上下文转换为键向量
        k = self.to_k(context)
        # 将上下文转换为值向量
        v = self.to_v(context)

        # 检查是否需要进行跨帧注意力的重新编程
        if n_times_crossframe_attn_in_self:
            # 进行跨帧注意力的重新编程,参考 https://arxiv.org/abs/2303.13439
            assert x.shape[0] % n_times_crossframe_attn_in_self == 0
            # k的维度处理,使用每n_times_crossframe_attn_in_self帧的一个
            k = repeat(
                k[::n_times_crossframe_attn_in_self],
                "b ... -> (b n) ...",
                n=n_times_crossframe_attn_in_self,
            )
            # v的维度处理,使用每n_times_crossframe_attn_in_self帧的一个
            v = repeat(
                v[::n_times_crossframe_attn_in_self],
                "b ... -> (b n) ...",
                n=n_times_crossframe_attn_in_self,
            )

        # 获取批次大小和特征维度
        b, _, _ = q.shape
        # 对 q, k, v 进行维度调整和重塑
        q, k, v = map(
            lambda t: t.unsqueeze(3)  # 在最后一维添加一个新维度
            .reshape(b, t.shape[1], self.heads, self.dim_head)  # 重塑为(batch, seq_len, heads, dim_head)
            .permute(0, 2, 1, 3)  # 重新排列维度为(batch, heads, seq_len, dim_head)
            .reshape(b * self.heads, t.shape[1], self.dim_head)  # 再次重塑为(batch * heads, seq_len, dim_head)
            .contiguous(),  # 确保内存连续性
            (q, k, v),  # 对 q, k, v 进行相同处理
        )

        # 实际计算注意力,这个过程是最不可或缺的
        out = xformers.ops.memory_efficient_attention(
            q, k, v, attn_bias=None, op=self.attention_op
        )

        # TODO: 将这个直接用作注意力操作中的偏置
        if exists(mask):
            # 如果存在遮掩,抛出未实现异常
            raise NotImplementedError
        # 对输出进行维度调整,适配最终输出形状
        out = (
            out.unsqueeze(0)  # 在最前面添加一个新维度
            .reshape(b, self.heads, out.shape[1], self.dim_head)  # 重塑为(batch, heads, seq_len, dim_head)
            .permute(0, 2, 1, 3)  # 重新排列维度为(batch, seq_len, heads, dim_head)
            .reshape(b, out.shape[1], self.heads * self.dim_head)  # 再次重塑为(batch, seq_len, heads * dim_head)
        )
        # 如果有额外的令牌,则移除它们
        if additional_tokens is not None:
            out = out[:, n_tokens_to_mask:]  # 切除被遮掩的令牌部分
        # 将输出转换为最终的输出格式
        return self.to_out(out)
# 定义基本的变换器模块类,继承自 nn.Module
class BasicTransformerBlock(nn.Module):
    # 定义可用的注意力模式,映射到对应的类
    ATTENTION_MODES = {
        "softmax": CrossAttention,  # 普通注意力
        "softmax-xformers": MemoryEfficientCrossAttention,  # 记忆高效注意力
    }

    # 初始化方法,设置模型的参数
    def __init__(
        self,
        dim,  # 输入的维度
        n_heads,  # 注意力头的数量
        d_head,  # 每个注意力头的维度
        dropout=0.0,  # dropout 比例
        context_dim=None,  # 上下文的维度
        gated_ff=True,  # 是否使用门控前馈网络
        checkpoint=True,  # 是否启用检查点
        disable_self_attn=False,  # 是否禁用自注意力
        attn_mode="softmax",  # 注意力模式,默认为 softmax
        sdp_backend=None,  # 后端配置
    ):
        # 调用父类初始化方法
        super().__init__()
        # 检查给定的注意力模式是否有效
        assert attn_mode in self.ATTENTION_MODES
        # 如果使用的注意力模式不支持,回退到默认模式
        if attn_mode != "softmax" and not XFORMERS_IS_AVAILABLE:
            print(
                f"Attention mode '{attn_mode}' is not available. Falling back to native attention. "
                f"This is not a problem in Pytorch >= 2.0. FYI, you are running with PyTorch version {torch.__version__}"
            )
            attn_mode = "softmax"  # 回退到 softmax 模式
        # 如果使用普通注意力且不支持,给出提示并调整模式
        elif attn_mode == "softmax" and not SDP_IS_AVAILABLE:
            print(
                "We do not support vanilla attention anymore, as it is too expensive. Sorry."
            )
            # 确保已安装 xformers
            if not XFORMERS_IS_AVAILABLE:
                assert (
                    False
                ), "Please install xformers via e.g. 'pip install xformers==0.0.16'"
            else:
                print("Falling back to xformers efficient attention.")
                attn_mode = "softmax-xformers"  # 使用 xformers 模式
        # 根据最终的注意力模式选择对应的类
        attn_cls = self.ATTENTION_MODES[attn_mode]
        # 检查 PyTorch 版本是否支持指定的后端
        if version.parse(torch.__version__) >= version.parse("2.0.0"):
            assert sdp_backend is None or isinstance(sdp_backend, SDPBackend)
        else:
            assert sdp_backend is None  # 对于旧版本,后端必须为 None
        self.disable_self_attn = disable_self_attn  # 保存禁用自注意力的设置
        # 初始化第一个注意力层
        self.attn1 = attn_cls(
            query_dim=dim,  # 查询的维度
            heads=n_heads,  # 注意力头数量
            dim_head=d_head,  # 每个头的维度
            dropout=dropout,  # dropout 比例
            context_dim=context_dim if self.disable_self_attn else None,  # 上下文维度
            backend=sdp_backend,  # 后端配置
        )  # 如果未禁用自注意力,则为自注意力层
        # 初始化前馈网络
        self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)  # 前馈网络配置
        # 初始化第二个注意力层
        self.attn2 = attn_cls(
            query_dim=dim,  # 查询的维度
            context_dim=context_dim,  # 上下文维度
            heads=n_heads,  # 注意力头数量
            dim_head=d_head,  # 每个头的维度
            dropout=dropout,  # dropout 比例
            backend=sdp_backend,  # 后端配置
        )  # 如果上下文维度为 None,则为自注意力层
        # 初始化层归一化层
        self.norm1 = nn.LayerNorm(dim)  # 第一个归一化层
        self.norm2 = nn.LayerNorm(dim)  # 第二个归一化层
        self.norm3 = nn.LayerNorm(dim)  # 第三个归一化层
        self.checkpoint = checkpoint  # 保存检查点设置
        # 如果启用检查点,输出相关信息(代码暂时注释掉)
        # if self.checkpoint:
        #     print(f"{self.__class__.__name__} is using checkpointing")

    # 前向传播方法,定义输入和上下文的处理
    def forward(
        self, x,  # 输入数据
        context=None,  # 上下文数据
        additional_tokens=None,  # 额外的 token
        n_times_crossframe_attn_in_self=0  # 跨帧自注意力的次数
    ):
        # 创建一个字典 kwargs,初始包含键 "x" 和参数 x 的值
        kwargs = {"x": x}

        # 如果 context 参数不为 None,则将其添加到 kwargs 字典中
        if context is not None:
            kwargs.update({"context": context})

        # 如果 additional_tokens 参数不为 None,则将其添加到 kwargs 字典中
        if additional_tokens is not None:
            kwargs.update({"additional_tokens": additional_tokens})

        # 如果 n_times_crossframe_attn_in_self 为真,则将其添加到 kwargs 字典中
        if n_times_crossframe_attn_in_self:
            kwargs.update(
                {"n_times_crossframe_attn_in_self": n_times_crossframe_attn_in_self}
            )

        # 返回调用 checkpoint 函数,传入 _forward 方法和相关参数
        return checkpoint(
            self._forward, (x, context), self.parameters(), self.checkpoint
        )

    def _forward(
        # 定义 _forward 方法,接收 x、context、additional_tokens 和 n_times_crossframe_attn_in_self 参数
        self, x, context=None, additional_tokens=None, n_times_crossframe_attn_in_self=0
    ):
        # 对 x 进行规范化,然后通过 attn1 方法进行自注意力计算,并根据条件选择 context 和其他参数
        x = (
            self.attn1(
                self.norm1(x),
                context=context if self.disable_self_attn else None,
                additional_tokens=additional_tokens,
                n_times_crossframe_attn_in_self=n_times_crossframe_attn_in_self
                if not self.disable_self_attn
                else 0,
            )
            + x  # 将 self.attn1 的输出与原始 x 相加
        )
        # 继续对 x 进行规范化,通过 attn2 方法进行自注意力计算
        x = (
            self.attn2(
                self.norm2(x), context=context, additional_tokens=additional_tokens
            )
            + x  # 将 self.attn2 的输出与当前 x 相加
        )
        # 对 x 进行规范化,然后通过前馈网络 ff 处理,再与原始 x 相加
        x = self.ff(self.norm3(x)) + x
        # 返回处理后的 x
        return x
# 定义基本的单层变换器块类,继承自 nn.Module
class BasicTransformerSingleLayerBlock(nn.Module):
    # 定义不同的注意力模式及其对应的类
    ATTENTION_MODES = {
        "softmax": CrossAttention,  # 标准注意力
        "softmax-xformers": MemoryEfficientCrossAttention  # 针对 A100s 的优化版本,速度可能略慢
        # (todo 可能依赖于 head_dim,需检查,对于 dim!=[16,32,64,128] 时退回到半优化内核)
    }

    # 初始化方法,设置基本参数
    def __init__(
        self,
        dim,  # 特征维度
        n_heads,  # 注意力头数
        d_head,  # 每个注意力头的维度
        dropout=0.0,  # 丢弃率
        context_dim=None,  # 上下文维度(可选)
        gated_ff=True,  # 是否使用门控前馈网络
        checkpoint=True,  # 是否启用检查点
        attn_mode="softmax",  # 注意力模式
    ):
        # 调用父类构造函数
        super().__init__()
        # 确保所选的注意力模式在定义的模式中
        assert attn_mode in self.ATTENTION_MODES
        # 获取对应的注意力类
        attn_cls = self.ATTENTION_MODES[attn_mode]
        # 初始化注意力层
        self.attn1 = attn_cls(
            query_dim=dim,  # 查询维度
            heads=n_heads,  # 头数
            dim_head=d_head,  # 每头的维度
            dropout=dropout,  # 丢弃率
            context_dim=context_dim,  # 上下文维度
        )
        # 初始化前馈网络
        self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
        # 初始化层归一化
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        # 设置检查点标志
        self.checkpoint = checkpoint

    # 前向传播方法
    def forward(self, x, context=None):
        # 使用检查点机制来进行前向传播
        return checkpoint(
            self._forward, (x, context), self.parameters(), self.checkpoint
        )

    # 实际的前向传播实现
    def _forward(self, x, context=None):
        # 通过注意力层进行处理并添加残差连接
        x = self.attn1(self.norm1(x), context=context) + x
        # 通过前馈网络处理并添加残差连接
        x = self.ff(self.norm2(x)) + x
        # 返回处理后的结果
        return x


# 定义空间变换器类,继承自 nn.Module
class SpatialTransformer(nn.Module):
    """
    适用于图像数据的变换器块。
    首先,将输入(即嵌入)投影并重塑为 b, t, d 形状。
    然后应用标准的变换器操作。
    最后,重塑为图像。
    新增:使用线性层提高效率,而不是 1x1 卷积。
    """

    # 初始化方法,设置变换器块的基本参数
    def __init__(
        self,
        in_channels,  # 输入通道数
        n_heads,  # 注意力头数
        d_head,  # 每个头的维度
        depth=1,  # 变换器块的深度
        dropout=0.0,  # 丢弃率
        context_dim=None,  # 上下文维度(可选)
        disable_self_attn=False,  # 是否禁用自注意力
        use_linear=False,  # 是否使用线性层
        attn_type="softmax",  # 注意力类型
        use_checkpoint=True,  # 是否使用检查点
        # sdp_backend=SDPBackend.FLASH_ATTENTION  # 可选的 SDP 后端
        sdp_backend=None,  # SDP 后端设置(默认值为 None)
    # 初始化类,调用父类构造函数
        ):
            super().__init__()
            # 打印构造信息(已注释)
            # print(
            #     f"constructing {self.__class__.__name__} of depth {depth} w/ {in_channels} channels and {n_heads} heads"
            # )
            # 从 omegaconf 导入 ListConfig 类
            from omegaconf import ListConfig
    
            # 如果 context_dim 存在且不是列表或 ListConfig 类型
            if exists(context_dim) and not isinstance(context_dim, (list, ListConfig)):
                # 将 context_dim 转换为列表
                context_dim = [context_dim]
            # 如果 context_dim 存在且是列表
            if exists(context_dim) and isinstance(context_dim, list):
                # 检查 depth 是否与 context_dim 的长度匹配
                if depth != len(context_dim):
                    # 打印警告信息(已注释)
                    # print(
                    #     f"WARNING: {self.__class__.__name__}: Found context dims {context_dim} of depth {len(context_dim)}, "
                    #     f"which does not match the specified 'depth' of {depth}. Setting context_dim to {depth * [context_dim[0]]} now."
                    # )
                    # 确保所有 context_dim 元素相同
                    assert all(
                        map(lambda x: x == context_dim[0], context_dim)
                    ), "need homogenous context_dim to match depth automatically"
                    # 如果不一致,设置 context_dim 为相同值的列表
                    context_dim = depth * [context_dim[0]]
            # 如果 context_dim 为 None
            elif context_dim is None:
                # 创建与 depth 长度相同的 None 列表
                context_dim = [None] * depth
            # 保存输入通道数
            self.in_channels = in_channels
            # 计算内部维度
            inner_dim = n_heads * d_head
            # 归一化层
            self.norm = Normalize(in_channels)
            # 如果不使用线性层
            if not use_linear:
                # 使用卷积层进行输入投影
                self.proj_in = nn.Conv2d(
                    in_channels, inner_dim, kernel_size=1, stride=1, padding=0
                )
            else:
                # 使用线性层进行输入投影
                self.proj_in = nn.Linear(in_channels, inner_dim)
    
            # 创建变压器模块列表
            self.transformer_blocks = nn.ModuleList(
                [
                    BasicTransformerBlock(
                        inner_dim,
                        n_heads,
                        d_head,
                        dropout=dropout,
                        context_dim=context_dim[d],
                        disable_self_attn=disable_self_attn,
                        attn_mode=attn_type,
                        checkpoint=use_checkpoint,
                        sdp_backend=sdp_backend,
                    )
                    # 遍历深度范围,生成多个变压器块
                    for d in range(depth)
                ]
            )
            # 如果不使用线性层
            if not use_linear:
                # 使用零初始化卷积层进行输出投影
                self.proj_out = zero_module(
                    nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
                )
            else:
                # 使用零初始化线性层进行输出投影(已注释)
                # self.proj_out = zero_module(nn.Linear(in_channels, inner_dim))
                self.proj_out = zero_module(nn.Linear(inner_dim, in_channels))
            # 保存是否使用线性层的标志
            self.use_linear = use_linear
    # 定义前向传播函数,接收输入 x 和可选的上下文 context
    def forward(self, x, context=None):
        # 注意:如果没有提供上下文,交叉注意力默认为自注意力
        if not isinstance(context, list):
            # 将上下文包装为列表,方便后续处理
            context = [context]
        # 获取输入张量的形状:批量大小 b,通道数 c,高 h,宽 w
        b, c, h, w = x.shape
        # 保存输入张量的原始值以便后续使用
        x_in = x
        # 对输入进行归一化处理
        x = self.norm(x)
        # 如果不使用线性变换,则进行投影变换
        if not self.use_linear:
            x = self.proj_in(x)
        # 重新排列张量的维度,将其从 (b, c, h, w) 变为 (b, h*w, c)
        x = rearrange(x, "b c h w -> b (h w) c").contiguous()
        # 如果使用线性变换,则再次进行投影变换
        if self.use_linear:
            x = self.proj_in(x)
        # 遍历所有的变换块
        for i, block in enumerate(self.transformer_blocks):
            # 如果不是第一个块且上下文长度为1,则使用同一个上下文
            if i > 0 and len(context) == 1:
                i = 0  # 每个块使用相同的上下文
            # 将输入传入当前变换块,并使用相应的上下文
            x = block(x, context=context[i])
        # 如果使用线性变换,则进行输出投影变换
        if self.use_linear:
            x = self.proj_out(x)
        # 重新排列张量的维度,将其从 (b, h*w, c) 变为 (b, c, h, w)
        x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w).contiguous()
        # 如果不使用线性变换,则进行输出投影变换
        if not self.use_linear:
            x = self.proj_out(x)
        # 返回处理后的张量与原始输入的和
        return x + x_in

.\cogview3-finetune\sat\sgm\modules\autoencoding\losses\__init__.py

# 导入类型提示 Any 和 Union
from typing import Any, Union

# 导入 PyTorch 及其神经网络模块
import torch
import torch.nn as nn
# 从 einops 导入 rearrange 函数
from einops import rearrange

# 导入自定义工具函数和类
from ....util import default, instantiate_from_config
# 从 lpips 库导入 LPIPS 损失类
from ..lpips.loss.lpips import LPIPS
# 从 lpips 模型导入 NLayerDiscriminator 和权重初始化函数
from ..lpips.model.model import NLayerDiscriminator, weights_init
# 从 vqperceptual 模块导入两种损失函数
from ..lpips.vqperceptual import hinge_d_loss, vanilla_d_loss


# 定义 adopt_weight 函数,调整权重值
def adopt_weight(weight, global_step, threshold=0, value=0.0):
    # 如果全局步数小于阈值,则将权重设为给定值
    if global_step < threshold:
        weight = value
    # 返回调整后的权重
    return weight


# 定义 LatentLPIPS 类,继承自 nn.Module
class LatentLPIPS(nn.Module):
    # 初始化方法,设置相关参数
    def __init__(
        self,
        decoder_config,
        perceptual_weight=1.0,
        latent_weight=1.0,
        scale_input_to_tgt_size=False,
        scale_tgt_to_input_size=False,
        perceptual_weight_on_inputs=0.0,
    ):
        # 调用父类构造方法
        super().__init__()
        # 设置输入大小缩放标志
        self.scale_input_to_tgt_size = scale_input_to_tgt_size
        self.scale_tgt_to_input_size = scale_tgt_to_input_size
        # 初始化解码器
        self.init_decoder(decoder_config)
        # 初始化感知损失模型,并设置为评估模式
        self.perceptual_loss = LPIPS().eval()
        # 设置感知损失和潜在损失的权重
        self.perceptual_weight = perceptual_weight
        self.latent_weight = latent_weight
        # 设置对输入的感知权重
        self.perceptual_weight_on_inputs = perceptual_weight_on_inputs

    # 定义初始化解码器的方法
    def init_decoder(self, config):
        # 从配置实例化解码器
        self.decoder = instantiate_from_config(config)
        # 如果解码器有 encoder 属性,则删除该属性
        if hasattr(self.decoder, "encoder"):
            del self.decoder.encoder
    # 定义前向传播函数,接收潜在输入、潜在预测、图像输入和数据集切分信息
    def forward(self, latent_inputs, latent_predictions, image_inputs, split="train"):
        # 初始化一个字典用于记录日志信息
        log = dict()
        # 计算潜在输入与潜在预测之间的均方差损失
        loss = (latent_inputs - latent_predictions) ** 2
        # 将均方差损失的平均值添加到日志中,使用切分名称作为键
        log[f"{split}/latent_l2_loss"] = loss.mean().detach()
        # 初始化图像重建变量
        image_reconstructions = None
        # 如果感知损失权重大于0,则进行感知损失的计算
        if self.perceptual_weight > 0.0:
            # 解码潜在预测生成图像重建
            image_reconstructions = self.decoder.decode(latent_predictions)
            # 解码潜在输入生成目标图像
            image_targets = self.decoder.decode(latent_inputs)
            # 计算感知损失
            perceptual_loss = self.perceptual_loss(
                image_targets.contiguous(), image_reconstructions.contiguous()
            )
            # 综合潜在损失和感知损失,更新总损失
            loss = (
                self.latent_weight * loss.mean()
                + self.perceptual_weight * perceptual_loss.mean()
            )
            # 将感知损失的平均值添加到日志中
            log[f"{split}/perceptual_loss"] = perceptual_loss.mean().detach()
    
        # 如果感知权重在输入上大于0,则进行相应的处理
        if self.perceptual_weight_on_inputs > 0.0:
            # 如果重建图像为空,则解码潜在预测生成图像重建
            image_reconstructions = default(
                image_reconstructions, self.decoder.decode(latent_predictions)
            )
            # 如果需要将输入图像缩放到目标图像大小
            if self.scale_input_to_tgt_size:
                image_inputs = torch.nn.functional.interpolate(
                    image_inputs,
                    image_reconstructions.shape[2:],
                    mode="bicubic",  # 使用双三次插值法
                    antialias=True,  # 使用抗锯齿
                )
            # 如果需要将目标图像缩放到输入图像大小
            elif self.scale_tgt_to_input_size:
                image_reconstructions = torch.nn.functional.interpolate(
                    image_reconstructions,
                    image_inputs.shape[2:],
                    mode="bicubic",  # 使用双三次插值法
                    antialias=True,  # 使用抗锯齿
                )
    
            # 计算与输入图像的感知损失
            perceptual_loss2 = self.perceptual_loss(
                image_inputs.contiguous(), image_reconstructions.contiguous()
            )
            # 更新总损失,加入输入的感知损失
            loss = loss + self.perceptual_weight_on_inputs * perceptual_loss2.mean()
            # 将输入的感知损失的平均值添加到日志中
            log[f"{split}/perceptual_loss_on_inputs"] = perceptual_loss2.mean().detach()
        # 返回总损失和日志信息
        return loss, log
# 定义一个带有判别器的通用 LPIPS 类,继承自 nn.Module
class GeneralLPIPSWithDiscriminator(nn.Module):
    # 初始化方法,接收多个参数以配置模型
    def __init__(
        self,
        disc_start: int,  # 判别器开始训练的迭代次数
        logvar_init: float = 0.0,  # 日志方差的初始值
        pixelloss_weight=1.0,  # 像素损失的权重
        disc_num_layers: int = 3,  # 判别器的层数
        disc_in_channels: int = 3,  # 判别器输入的通道数
        disc_factor: float = 1.0,  # 判别器的缩放因子
        disc_weight: float = 1.0,  # 判别器损失的权重
        perceptual_weight: float = 1.0,  # 感知损失的权重
        disc_loss: str = "hinge",  # 判别器使用的损失类型
        scale_input_to_tgt_size: bool = False,  # 是否将输入缩放到目标大小
        dims: int = 2,  # 数据的维度
        learn_logvar: bool = False,  # 是否学习日志方差
        regularization_weights: Union[None, dict] = None,  # 正则化权重
    ):
        # 调用父类的初始化方法
        super().__init__()
        self.dims = dims  # 保存维度信息
        # 如果维度大于2,打印警告信息
        if self.dims > 2:
            print(
                f"running with dims={dims}. This means that for perceptual loss calculation, "
                f"the LPIPS loss will be applied to each frame independently. "
            )
        self.scale_input_to_tgt_size = scale_input_to_tgt_size  # 保存输入缩放标志
        # 确保判别器损失类型为 hinge 或 vanilla
        assert disc_loss in ["hinge", "vanilla"]
        self.pixel_weight = pixelloss_weight  # 保存像素损失权重
        self.perceptual_loss = LPIPS().eval()  # 初始化 LPIPS 感知损失并设置为评估模式
        self.perceptual_weight = perceptual_weight  # 保存感知损失权重
        # 输出日志方差,作为可学习的参数
        self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init)
        self.learn_logvar = learn_logvar  # 保存是否学习日志方差的标志

        # 初始化 NLayerDiscriminator 作为判别器
        self.discriminator = NLayerDiscriminator(
            input_nc=disc_in_channels, n_layers=disc_num_layers, use_actnorm=False
        ).apply(weights_init)  # 应用权重初始化
        self.discriminator_iter_start = disc_start  # 保存判别器开始训练的迭代次数
        # 根据损失类型选择合适的损失函数
        self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss
        self.disc_factor = disc_factor  # 保存判别器缩放因子
        self.discriminator_weight = disc_weight  # 保存判别器损失权重
        self.regularization_weights = default(regularization_weights, {})  # 设置正则化权重,默认为空字典

    # 获取可训练的参数
    def get_trainable_parameters(self) -> Any:
        return self.discriminator.parameters()  # 返回判别器的参数

    # 获取可训练的自编码器参数
    def get_trainable_autoencoder_parameters(self) -> Any:
        if self.learn_logvar:  # 如果学习日志方差
            yield self.logvar  # 生成日志方差
        yield from ()  # 返回空生成器

    # 计算自适应权重
    def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
        if last_layer is not None:  # 如果提供了最后一层
            # 计算负对数似然损失的梯度
            nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
            # 计算生成器损失的梯度
            g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
        else:
            # 如果没有提供最后一层,使用类属性中的最后一层
            nll_grads = torch.autograd.grad(
                nll_loss, self.last_layer[0], retain_graph=True
            )[0]
            g_grads = torch.autograd.grad(
                g_loss, self.last_layer[0], retain_graph=True
            )[0]

        # 计算判别器权重
        d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
        d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()  # 限制权重范围并分离计算图
        d_weight = d_weight * self.discriminator_weight  # 应用判别器权重
        return d_weight  # 返回计算得到的权重

    # 前向传播方法
    def forward(
        self,
        regularization_log,  # 正则化日志
        inputs,  # 输入数据
        reconstructions,  # 重建数据
        optimizer_idx,  # 优化器索引
        global_step,  # 全局步数
        last_layer=None,  # 最后一层(可选)
        split="train",  # 数据集划分(训练或验证)
        weights=None,  # 权重(可选)

.\cogview3-finetune\sat\sgm\modules\autoencoding\lpips\loss\lpips.py

# 从 https://github.com/richzhang/PerceptualSimilarity/tree/master/models 中剥离的版本
"""Stripped version of https://github.com/richzhang/PerceptualSimilarity/tree/master/models"""

# 从 collections 模块导入 namedtuple
from collections import namedtuple

# 导入 PyTorch 和神经网络模块
import torch
import torch.nn as nn
# 从 torchvision 导入模型
from torchvision import models

# 从上层目录的 util 模块导入获取检查点路径的函数
from ..util import get_ckpt_path


# 定义 LPIPS 类,继承自 nn.Module
class LPIPS(nn.Module):
    # 学习的感知度量
    def __init__(self, use_dropout=True):
        # 调用父类构造函数
        super().__init__()
        # 初始化缩放层
        self.scaling_layer = ScalingLayer()
        # 定义特征通道数量,针对 VGG16 特征
        self.chns = [64, 128, 256, 512, 512]  # vg16 features
        # 加载预训练的 VGG16 模型,不计算其梯度
        self.net = vgg16(pretrained=True, requires_grad=False)
        # 为每个通道初始化线性层
        self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)
        self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)
        self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)
        self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
        self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
        # 从预训练模型中加载权重
        self.load_from_pretrained()
        # 禁用所有参数的梯度更新
        for param in self.parameters():
            param.requires_grad = False

    # 从预训练模型加载权重的方法
    def load_from_pretrained(self, name="vgg_lpips"):
        # 获取检查点文件的路径
        ckpt = get_ckpt_path(name, "sgm/modules/autoencoding/lpips/loss")
        # 加载权重到当前模型中,严格匹配状态字典
        self.load_state_dict(
            torch.load(ckpt, map_location=torch.device("cpu")), strict=False
        )
        # 打印加载的模型路径
        print("loaded pretrained LPIPS loss from {}".format(ckpt))

    # 类方法,用于从预训练模型创建实例
    @classmethod
    def from_pretrained(cls, name="vgg_lpips"):
        # 检查模型名称是否有效
        if name != "vgg_lpips":
            raise NotImplementedError
        # 创建当前类的实例
        model = cls()
        # 获取检查点路径
        ckpt = get_ckpt_path(name)
        # 加载权重到模型中
        model.load_state_dict(
            torch.load(ckpt, map_location=torch.device("cpu")), strict=False
        )
        # 返回模型实例
        return model

    # 前向传播方法
    def forward(self, input, target):
        # 对输入和目标进行缩放处理
        in0_input, in1_input = (self.scaling_layer(input), self.scaling_layer(target))
        # 通过网络提取特征
        outs0, outs1 = self.net(in0_input), self.net(in1_input)
        # 初始化特征和差异字典
        feats0, feats1, diffs = {}, {}, {}
        # 收集线性层
        lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4]
        # 遍历特征通道
        for kk in range(len(self.chns)):
            # 标准化特征并计算差异
            feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(
                outs1[kk]
            )
            diffs[kk] = (feats0[kk] - feats1[kk]) ** 2

        # 计算每个通道的结果
        res = [
            spatial_average(lins[kk].model(diffs[kk]), keepdim=True)
            for kk in range(len(self.chns))
        ]
        # 初始化最终值
        val = res[0]
        # 累加结果
        for l in range(1, len(self.chns)):
            val += res[l]
        # 返回最终的值
        return val


# 定义缩放层类,继承自 nn.Module
class ScalingLayer(nn.Module):
    # 构造函数
    def __init__(self):
        # 调用父类构造函数
        super(ScalingLayer, self).__init__()
        # 注册偏移量的缓冲区
        self.register_buffer(
            "shift", torch.Tensor([-0.030, -0.088, -0.188])[None, :, None, None]
        )
        # 注册缩放因子的缓冲区
        self.register_buffer(
            "scale", torch.Tensor([0.458, 0.448, 0.450])[None, :, None, None]
        )

    # 前向传播方法
    def forward(self, inp):
        # 进行缩放操作
        return (inp - self.shift) / self.scale


# 定义线性层类,继承自 nn.Module
class NetLinLayer(nn.Module):
    """一个单独的线性层,执行 1x1 卷积"""
    # 初始化网络层,接收输入通道数、输出通道数和是否使用 dropout
        def __init__(self, chn_in, chn_out=1, use_dropout=False):
            # 调用父类的初始化方法
            super(NetLinLayer, self).__init__()
            # 根据是否使用 dropout 创建层列表
            layers = (
                [
                    nn.Dropout(),  # 添加 dropout 层
                ]
                if (use_dropout)  # 如果使用 dropout
                else []  # 否则为空列表
            )
            # 添加卷积层到层列表,输入通道为 chn_in,输出通道为 chn_out,卷积核大小为 1
            layers += [
                nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False),
            ]
            # 将层列表包装成一个顺序容器,便于按顺序调用各层
            self.model = nn.Sequential(*layers)
# 定义一个 VGG16 类,继承自 PyTorch 的 Module 类
class vgg16(torch.nn.Module):
    # 初始化函数,接受是否需要梯度和是否使用预训练模型的参数
    def __init__(self, requires_grad=False, pretrained=True):
        # 调用父类的初始化函数
        super(vgg16, self).__init__()
        # 获取预训练的 VGG16 特征提取层
        vgg_pretrained_features = models.vgg16(pretrained=pretrained).features
        # 定义五个序列层
        self.slice1 = torch.nn.Sequential()
        self.slice2 = torch.nn.Sequential()
        self.slice3 = torch.nn.Sequential()
        self.slice4 = torch.nn.Sequential()
        self.slice5 = torch.nn.Sequential()
        # 设置切片的数量为 5
        self.N_slices = 5
        # 将前4层添加到 slice1 中
        for x in range(4):
            self.slice1.add_module(str(x), vgg_pretrained_features[x])
        # 将第 4 到 8 层添加到 slice2 中
        for x in range(4, 9):
            self.slice2.add_module(str(x), vgg_pretrained_features[x])
        # 将第 9 到 15 层添加到 slice3 中
        for x in range(9, 16):
            self.slice3.add_module(str(x), vgg_pretrained_features[x])
        # 将第 16 到 22 层添加到 slice4 中
        for x in range(16, 23):
            self.slice4.add_module(str(x), vgg_pretrained_features[x])
        # 将第 23 到 29 层添加到 slice5 中
        for x in range(23, 30):
            self.slice5.add_module(str(x), vgg_pretrained_features[x])
        # 如果不需要梯度,则冻结所有参数
        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False

    # 定义前向传播函数
    def forward(self, X):
        # 将输入通过 slice1 层
        h = self.slice1(X)
        # 保存第一层的输出
        h_relu1_2 = h
        # 将输出通过 slice2 层
        h = self.slice2(h)
        # 保存第二层的输出
        h_relu2_2 = h
        # 将输出通过 slice3 层
        h = self.slice3(h)
        # 保存第三层的输出
        h_relu3_3 = h
        # 将输出通过 slice4 层
        h = self.slice4(h)
        # 保存第四层的输出
        h_relu4_3 = h
        # 将输出通过 slice5 层
        h = self.slice5(h)
        # 保存第五层的输出
        h_relu5_3 = h
        # 创建一个命名元组来存储不同层的输出
        vgg_outputs = namedtuple(
            "VggOutputs", ["relu1_2", "relu2_2", "relu3_3", "relu4_3", "relu5_3"]
        )
        # 将各层输出组合成一个元组
        out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)
        # 返回元组
        return out


# 定义一个函数,用于规范化张量
def normalize_tensor(x, eps=1e-10):
    # 计算张量的 L2 范数
    norm_factor = torch.sqrt(torch.sum(x**2, dim=1, keepdim=True))
    # 返回规范化后的张量,避免除以零
    return x / (norm_factor + eps)


# 定义一个空间平均函数
def spatial_average(x, keepdim=True):
    # 在高和宽维度上计算平均值
    return x.mean([2, 3], keepdim=keepdim)

.\cogview3-finetune\sat\sgm\modules\autoencoding\lpips\loss\__init__.py


.\cogview3-finetune\sat\sgm\modules\autoencoding\lpips\model\model.py

# 导入 functools 模块,用于函数工具
import functools

# 导入 nn 模块,用于构建神经网络
import torch.nn as nn

# 从上级目录的 util 模块导入 ActNorm
from ..util import ActNorm


# 定义权重初始化函数
def weights_init(m):
    # 获取模块的类名
    classname = m.__class__.__name__
    # 如果类名包含 "Conv",则初始化卷积层的权重
    if classname.find("Conv") != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    # 如果类名包含 "BatchNorm",则初始化批归一化层的权重和偏置
    elif classname.find("BatchNorm") != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)


# 定义 NLayerDiscriminator 类,继承自 nn.Module
class NLayerDiscriminator(nn.Module):
    """定义一个 PatchGAN 判别器,如 Pix2Pix 所示
    --> 参见 https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py
    """

    # 构造函数,初始化参数
    def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False):
        """构造一个 PatchGAN 判别器
        参数:
            input_nc (int)  -- 输入图像的通道数
            ndf (int)       -- 最后一个卷积层中的滤波器数量
            n_layers (int)  -- 判别器中的卷积层数量
            norm_layer      -- 归一化层
        """
        # 调用父类的构造函数
        super(NLayerDiscriminator, self).__init__()
        # 根据是否使用 ActNorm 来选择归一化层
        if not use_actnorm:
            norm_layer = nn.BatchNorm2d
        else:
            norm_layer = ActNorm
        # 如果归一化层是 functools.partial 类型,判断是否使用偏置
        if (
            type(norm_layer) == functools.partial
        ):  # 不需要使用偏置,因为 BatchNorm2d 有仿射参数
            use_bias = norm_layer.func != nn.BatchNorm2d
        else:
            use_bias = norm_layer != nn.BatchNorm2d

        kw = 4  # 卷积核的大小
        padw = 1  # 卷积层的填充
        # 初始化序列,添加第一个卷积层和 LeakyReLU 激活函数
        sequence = [
            nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw),
            nn.LeakyReLU(0.2, True),
        ]
        nf_mult = 1  # 当前滤波器倍数
        nf_mult_prev = 1  # 上一层的滤波器倍数
        # 逐渐增加滤波器数量,构建后续的卷积层
        for n in range(1, n_layers):  # 逐渐增加滤波器数量
            nf_mult_prev = nf_mult
            nf_mult = min(2**n, 8)  # 滤波器倍数最大为 8
            sequence += [
                nn.Conv2d(
                    ndf * nf_mult_prev,  # 输入通道数
                    ndf * nf_mult,  # 输出通道数
                    kernel_size=kw,  # 卷积核大小
                    stride=2,  # 步幅
                    padding=padw,  # 填充
                    bias=use_bias,  # 是否使用偏置
                ),
                norm_layer(ndf * nf_mult),  # 添加归一化层
                nn.LeakyReLU(0.2, True),  # 添加 LeakyReLU 激活函数
            ]

        nf_mult_prev = nf_mult  # 更新上一个滤波器倍数
        nf_mult = min(2**n_layers, 8)  # 计算最终滤波器倍数
        sequence += [
            nn.Conv2d(
                ndf * nf_mult_prev,  # 输入通道数
                ndf * nf_mult,  # 输出通道数
                kernel_size=kw,  # 卷积核大小
                stride=1,  # 步幅
                padding=padw,  # 填充
                bias=use_bias,  # 是否使用偏置
            ),
            norm_layer(ndf * nf_mult),  # 添加归一化层
            nn.LeakyReLU(0.2, True),  # 添加 LeakyReLU 激活函数
        ]

        sequence += [
            nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)
        ]  # 输出 1 通道的预测图
        # 将所有层组合成一个序列
        self.main = nn.Sequential(*sequence)

    # 定义前向传播函数
    def forward(self, input):
        """标准前向传播。"""
        return self.main(input)  # 将输入传入序列并返回输出

.\cogview3-finetune\sat\sgm\modules\autoencoding\lpips\model\__init__.py


.\cogview3-finetune\sat\sgm\modules\autoencoding\lpips\util.py

# 导入所需的库
import hashlib  # 导入 hashlib 库用于计算文件的 MD5 哈希
import os  # 导入 os 库用于文件和目录操作

import requests  # 导入 requests 库用于发送 HTTP 请求
import torch  # 导入 PyTorch 库用于深度学习
import torch.nn as nn  # 导入 nn 模块用于构建神经网络
from tqdm import tqdm  # 导入 tqdm 库用于显示进度条

# 定义模型 URL 映射字典
URL_MAP = {"vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1"}

# 定义模型检查点文件名映射字典
CKPT_MAP = {"vgg_lpips": "vgg.pth"}

# 定义模型 MD5 哈希值映射字典
MD5_MAP = {"vgg_lpips": "d507d7349b931f0638a25a48a722f98a"}


# 定义下载函数
def download(url, local_path, chunk_size=1024):
    # 创建本地路径的父目录,如果不存在则创建
    os.makedirs(os.path.split(local_path)[0], exist_ok=True)
    # 发送 GET 请求以流式下载文件
    with requests.get(url, stream=True) as r:
        # 获取响应头中的内容长度
        total_size = int(r.headers.get("content-length", 0))
        # 使用 tqdm 显示下载进度条
        with tqdm(total=total_size, unit="B", unit_scale=True) as pbar:
            # 以二进制写入模式打开本地文件
            with open(local_path, "wb") as f:
                # 分块读取响应内容
                for data in r.iter_content(chunk_size=chunk_size):
                    # 如果读取到数据,则写入文件
                    if data:
                        f.write(data)  # 写入数据到文件
                        pbar.update(chunk_size)  # 更新进度条


# 定义 MD5 哈希函数
def md5_hash(path):
    # 以二进制读取模式打开指定路径的文件
    with open(path, "rb") as f:
        content = f.read()  # 读取文件内容
    # 返回文件内容的 MD5 哈希值
    return hashlib.md5(content).hexdigest()


# 定义获取检查点路径的函数
def get_ckpt_path(name, root, check=False):
    # 确保给定的模型名称在 URL 映射中
    assert name in URL_MAP
    # 组合根目录和检查点文件名,形成完整路径
    path = os.path.join(root, CKPT_MAP[name])
    # 检查文件是否存在或是否需要重新下载
    if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]):
        # 打印下载信息
        print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path))
        download(URL_MAP[name], path)  # 下载文件
        md5 = md5_hash(path)  # 计算下载文件的 MD5 哈希值
        # 确保下载的文件 MD5 值与预期匹配
        assert md5 == MD5_MAP[name], md5
    return path  # 返回检查点路径


# 定义一个自定义的神经网络模块
class ActNorm(nn.Module):
    # 构造函数,初始化参数
    def __init__(
        self, num_features, logdet=False, affine=True, allow_reverse_init=False
    ):
        assert affine  # 确保启用仿射变换
        super().__init__()  # 调用父类构造函数
        self.logdet = logdet  # 保存 logdet 标志
        # 定义可学习的均值参数
        self.loc = nn.Parameter(torch.zeros(1, num_features, 1, 1))
        # 定义可学习的缩放参数
        self.scale = nn.Parameter(torch.ones(1, num_features, 1, 1))
        self.allow_reverse_init = allow_reverse_init  # 保存是否允许反向初始化标志

        # 注册一个缓冲区,用于记录初始化状态
        self.register_buffer("initialized", torch.tensor(0, dtype=torch.uint8))

    # 定义初始化函数
    def initialize(self, input):
        with torch.no_grad():  # 在不计算梯度的上下文中
            # 将输入张量重排列并展平
            flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1)
            # 计算展平后的均值
            mean = (
                flatten.mean(1)
                .unsqueeze(1)
                .unsqueeze(2)
                .unsqueeze(3)
                .permute(1, 0, 2, 3)
            )
            # 计算展平后的标准差
            std = (
                flatten.std(1)
                .unsqueeze(1)
                .unsqueeze(2)
                .unsqueeze(3)
                .permute(1, 0, 2, 3)
            )

            # 将均值复制到 loc 参数
            self.loc.data.copy_(-mean)
            # 将标准差的倒数复制到 scale 参数
            self.scale.data.copy_(1 / (std + 1e-6))
    # 定义前向传播函数,接受输入和反向标志
        def forward(self, input, reverse=False):
            # 如果反向标志为真,调用反向函数处理输入
            if reverse:
                return self.reverse(input)
            # 检查输入的形状是否为二维
            if len(input.shape) == 2:
                # 将二维输入扩展为四维,增加两个新的维度
                input = input[:, :, None, None]
                squeeze = True  # 标记为需要压缩
            else:
                squeeze = False  # 不需要压缩
    
            # 解包输入的高度和宽度
            _, _, height, width = input.shape
    
            # 如果处于训练状态且尚未初始化
            if self.training and self.initialized.item() == 0:
                # 初始化参数
                self.initialize(input)
                # 标记为已初始化
                self.initialized.fill_(1)
    
            # 根据比例因子和位移量调整输入
            h = self.scale * (input + self.loc)
    
            # 如果需要压缩,移除最后两个维度
            if squeeze:
                h = h.squeeze(-1).squeeze(-1)
    
            # 如果需要计算对数行列式
            if self.logdet:
                # 计算比例因子的绝对值的对数
                log_abs = torch.log(torch.abs(self.scale))
                # 计算对数行列式的值
                logdet = height * width * torch.sum(log_abs)
                # 生成与批量大小相同的对数行列式张量
                logdet = logdet * torch.ones(input.shape[0]).to(input)
                # 返回调整后的输出和对数行列式
                return h, logdet
    
            # 返回调整后的输出
            return h
    
    # 定义反向传播函数,接受输出
        def reverse(self, output):
            # 如果处于训练状态且尚未初始化
            if self.training and self.initialized.item() == 0:
                # 如果不允许在反向方向初始化,则抛出错误
                if not self.allow_reverse_init:
                    raise RuntimeError(
                        "Initializing ActNorm in reverse direction is "
                        "disabled by default. Use allow_reverse_init=True to enable."
                    )
                else:
                    # 初始化参数
                    self.initialize(output)
                    # 标记为已初始化
                    self.initialized.fill_(1)
    
            # 检查输出的形状是否为二维
            if len(output.shape) == 2:
                # 将二维输出扩展为四维,增加两个新的维度
                output = output[:, :, None, None]
                squeeze = True  # 标记为需要压缩
            else:
                squeeze = False  # 不需要压缩
    
            # 根据比例因子和位移量调整输出
            h = output / self.scale - self.loc
    
            # 如果需要压缩,移除最后两个维度
            if squeeze:
                h = h.squeeze(-1).squeeze(-1)
            # 返回调整后的输出
            return h

.\cogview3-finetune\sat\sgm\modules\autoencoding\lpips\vqperceptual.py

# 导入 PyTorch 库
import torch
# 导入 PyTorch 中的功能性模块
import torch.nn.functional as F


# 定义对抗训练中的判别器损失函数(hinge 损失)
def hinge_d_loss(logits_real, logits_fake):
    # 计算真实样本的损失,使用 ReLU 激活函数
    loss_real = torch.mean(F.relu(1.0 - logits_real))
    # 计算假样本的损失,使用 ReLU 激活函数
    loss_fake = torch.mean(F.relu(1.0 + logits_fake))
    # 计算总的判别器损失,取真实和假样本损失的平均值
    d_loss = 0.5 * (loss_real + loss_fake)
    # 返回判别器损失
    return d_loss


# 定义对抗训练中的另一种判别器损失函数(vanilla 损失)
def vanilla_d_loss(logits_real, logits_fake):
    # 计算总的判别器损失,使用 softplus 函数
    d_loss = 0.5 * (
        torch.mean(torch.nn.functional.softplus(-logits_real))  # 计算真实样本的 softplus 损失
        + torch.mean(torch.nn.functional.softplus(logits_fake))   # 计算假样本的 softplus 损失
    )
    # 返回判别器损失
    return d_loss

.\cogview3-finetune\sat\sgm\modules\autoencoding\lpips\__init__.py

请提供需要添加注释的代码片段。

.\cogview3-finetune\sat\sgm\modules\autoencoding\regularizers\__init__.py

# 导入抽象方法的库
from abc import abstractmethod
# 导入任意类型和元组类型
from typing import Any, Tuple

# 导入 PyTorch 库
import torch
# 导入神经网络模块
import torch.nn as nn
# 导入功能模块
import torch.nn.functional as F

# 导入对角高斯分布类
from ....modules.distributions.distributions import DiagonalGaussianDistribution


# 定义抽象正则化器类,继承自 nn.Module
class AbstractRegularizer(nn.Module):
    # 初始化方法
    def __init__(self):
        super().__init__()  # 调用父类构造方法

    # 前向传播方法,接受一个张量,返回一个张量和字典
    def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]:
        raise NotImplementedError()  # 抛出未实现错误

    # 抽象方法,获取可训练的参数
    @abstractmethod
    def get_trainable_parameters(self) -> Any:
        raise NotImplementedError()  # 抛出未实现错误


# 定义对角高斯正则化器类,继承自 AbstractRegularizer
class DiagonalGaussianRegularizer(AbstractRegularizer):
    # 初始化方法,接受一个布尔参数,默认值为 True
    def __init__(self, sample: bool = True):
        super().__init__()  # 调用父类构造方法
        self.sample = sample  # 存储是否采样的参数

    # 获取可训练参数的方法,返回一个生成器
    def get_trainable_parameters(self) -> Any:
        yield from ()  # 生成器为空,表示没有可训练参数

    # 前向传播方法,接受一个张量,返回一个张量和字典
    def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]:
        log = dict()  # 创建一个空字典,用于记录日志
        posterior = DiagonalGaussianDistribution(z)  # 创建对角高斯分布实例
        if self.sample:  # 如果需要采样
            z = posterior.sample()  # 从后验分布中采样
        else:  # 如果不需要采样
            z = posterior.mode()  # 取后验分布的众数
        kl_loss = posterior.kl()  # 计算 Kullback-Leibler 散度损失
        kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]  # 计算均值损失
        log["kl_loss"] = kl_loss  # 将 KL 损失记录到日志中
        return z, log  # 返回采样或众数以及日志


# 定义测量困惑度的函数,接受预测的索引和质心数量
def measure_perplexity(predicted_indices, num_centroids):
    # src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py
    # 评估聚类的困惑度。当困惑度等于 num_embeddings 时,所有聚类被完全均匀使用
    encodings = (
        F.one_hot(predicted_indices, num_centroids).float().reshape(-1, num_centroids)  # 将预测索引转换为独热编码
    )
    avg_probs = encodings.mean(0)  # 计算每个质心的平均概率
    perplexity = (-(avg_probs * torch.log(avg_probs + 1e-10)).sum()).exp()  # 计算困惑度
    cluster_use = torch.sum(avg_probs > 0)  # 计算被使用的聚类数量
    return perplexity, cluster_use  # 返回困惑度和聚类使用情况

.\cogview3-finetune\sat\sgm\modules\autoencoding\__init__.py

请提供需要注释的代码,我将帮助你添加详细的注释。

.\cogview3-finetune\sat\sgm\modules\diffusionmodules\denoiser.py

# 导入所需的类型定义,便于后续使用
from typing import Dict, Union

# 导入 PyTorch 及其神经网络模块
import torch
import torch.nn as nn

# 从上层目录导入实用工具函数
from ...util import append_dims, instantiate_from_config


# 定义去噪器类,继承自 nn.Module
class Denoiser(nn.Module):
    # 初始化方法,接受加权配置和缩放配置
    def __init__(self, weighting_config, scaling_config):
        # 调用父类构造函数
        super().__init__()

        # 根据加权配置实例化加权模块
        self.weighting = instantiate_from_config(weighting_config)
        # 根据缩放配置实例化缩放模块
        self.scaling = instantiate_from_config(scaling_config)

    # 可能对 sigma 进行量化的占位符方法
    def possibly_quantize_sigma(self, sigma):
        return sigma  # 返回未修改的 sigma

    # 可能对噪声 c_noise 进行量化的占位符方法
    def possibly_quantize_c_noise(self, c_noise):
        return c_noise  # 返回未修改的 c_noise

    # 计算加权后的 sigma 值
    def w(self, sigma):
        return self.weighting(sigma)  # 返回加权后的 sigma

    # 前向传播方法,定义模型如何处理输入
    def forward(
        self,
        network: nn.Module,  # 网络模型
        input: torch.Tensor,  # 输入张量
        sigma: torch.Tensor,  # sigma 张量
        cond: Dict,  # 条件字典
        **additional_model_inputs,  # 其他模型输入
    ) -> torch.Tensor:
        # 可能对 sigma 进行量化
        sigma = self.possibly_quantize_sigma(sigma)
        # 获取 sigma 的形状
        sigma_shape = sigma.shape
        # 调整 sigma 的维度,以匹配输入的维度
        sigma = append_dims(sigma, input.ndim)
        # 使用缩放模块处理 sigma 和额外模型输入,获取多个输出
        c_skip, c_out, c_in, c_noise = self.scaling(sigma, **additional_model_inputs)
        # 可能对噪声 c_noise 进行量化,并调整形状
        c_noise = self.possibly_quantize_c_noise(c_noise.reshape(sigma_shape))
        # 返回经过网络处理的结果,结合输入和噪声
        return (
            network(input * c_in, c_noise, cond, **additional_model_inputs) * c_out
            + input * c_skip
        )


# 定义离散去噪器类,继承自 Denoiser
class DiscreteDenoiser(Denoiser):
    # 初始化方法,接受多个配置参数
    def __init__(
        self,
        weighting_config,  # 加权配置
        scaling_config,  # 缩放配置
        num_idx,  # 索引数量
        discretization_config,  # 离散化配置
        do_append_zero=False,  # 是否附加零
        quantize_c_noise=True,  # 是否量化 c_noise
        flip=True,  # 是否翻转
    ):
        # 调用父类构造函数
        super().__init__(weighting_config, scaling_config)
        # 根据离散化配置实例化 sigma
        sigmas = instantiate_from_config(discretization_config)(
            num_idx, do_append_zero=do_append_zero, flip=flip
        )
        # 保存 sigma 值
        self.sigmas = sigmas
        # self.register_buffer("sigmas", sigmas)  # (可选)注册一个持久化的缓冲区
        # 设置是否量化 c_noise
        self.quantize_c_noise = quantize_c_noise

    # 将 sigma 转换为索引的方法
    def sigma_to_idx(self, sigma):
        # 计算 sigma 与 sigma 列表的距离
        dists = sigma - self.sigmas.to(sigma.device)[:, None]
        # 返回距离最小的索引
        return dists.abs().argmin(dim=0).view(sigma.shape)

    # 将索引转换为 sigma 的方法
    def idx_to_sigma(self, idx):
        return self.sigmas.to(idx.device)[idx]  # 根据索引返回对应的 sigma

    # 可能对 sigma 进行量化的重写方法
    def possibly_quantize_sigma(self, sigma):
        return self.idx_to_sigma(self.sigma_to_idx(sigma))  # 返回量化后的 sigma

    # 可能对噪声 c_noise 进行量化的重写方法
    def possibly_quantize_c_noise(self, c_noise):
        if self.quantize_c_noise:  # 如果选择量化 c_noise
            return self.sigma_to_idx(c_noise)  # 返回量化后的索引
        else:
            return c_noise  # 返回未修改的 c_noise

.\cogview3-finetune\sat\sgm\modules\diffusionmodules\denoiser_scaling.py

# 从 abc 模块导入抽象基类和抽象方法
from abc import ABC, abstractmethod
# 从 typing 模块导入任意类型和元组
from typing import Any, Tuple

# 导入 PyTorch 库
import torch


# 定义一个抽象基类 DenoiserScaling
class DenoiserScaling(ABC):
    # 定义一个抽象方法 __call__
    @abstractmethod
    def __call__(
        self, sigma: torch.Tensor, **additional_model_inputs
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        # 抽象方法没有具体实现
        pass


# 定义 EDMScaling 类
class EDMScaling:
    # 初始化方法,接受一个 sigma 数据,默认值为 0.5
    def __init__(self, sigma_data: float = 0.5):
        # 将 sigma_data 保存为实例变量
        self.sigma_data = sigma_data

    # 定义一个可调用方法 __call__
    def __call__(
        self, sigma: torch.Tensor, **additional_model_inputs
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        # 计算 c_skip 的值
        c_skip = self.sigma_data**2 / (sigma**2 + self.sigma_data**2)
        # 计算 c_out 的值
        c_out = sigma * self.sigma_data / (sigma**2 + self.sigma_data**2) ** 0.5
        # 计算 c_in 的值
        c_in = 1 / (sigma**2 + self.sigma_data**2) ** 0.5
        # 计算 c_noise 的值
        c_noise = 0.25 * sigma.log()
        # 返回四个计算结果
        return c_skip, c_out, c_in, c_noise


# 定义 EpsScaling 类
class EpsScaling:
    # 定义一个可调用方法 __call__
    def __call__(
        self, sigma: torch.Tensor, **additional_model_inputs
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        # 创建与 sigma 相同形状的全 1 张量作为 c_skip
        c_skip = torch.ones_like(sigma, device=sigma.device)
        # 计算 c_out 的值为 -sigma
        c_out = -sigma
        # 计算 c_in 的值
        c_in = 1 / (sigma**2 + 1.0) ** 0.5
        # 复制 sigma 作为 c_noise
        c_noise = sigma.clone()
        # 返回四个计算结果
        return c_skip, c_out, c_in, c_noise


# 定义 VScaling 类
class VScaling:
    # 定义一个可调用方法 __call__
    def __call__(
        self, sigma: torch.Tensor, **additional_model_inputs
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        # 计算 c_skip 的值
        c_skip = 1.0 / (sigma**2 + 1.0)
        # 计算 c_out 的值
        c_out = -sigma / (sigma**2 + 1.0) ** 0.5
        # 计算 c_in 的值
        c_in = 1.0 / (sigma**2 + 1.0) ** 0.5
        # 复制 sigma 作为 c_noise
        c_noise = sigma.clone()
        # 返回四个计算结果
        return c_skip, c_out, c_in, c_noise


# 定义 VScalingWithEDMcNoise 类,继承自 DenoiserScaling
class VScalingWithEDMcNoise(DenoiserScaling):
    # 定义一个可调用方法 __call__
    def __call__(
        self, sigma: torch.Tensor, **additional_model_inputs
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        # 计算 c_skip 的值
        c_skip = 1.0 / (sigma**2 + 1.0)
        # 计算 c_out 的值
        c_out = -sigma / (sigma**2 + 1.0) ** 0.5
        # 计算 c_in 的值
        c_in = 1.0 / (sigma**2 + 1.0) ** 0.5
        # 计算 c_noise 的值
        c_noise = 0.25 * sigma.log()
        # 返回四个计算结果
        return c_skip, c_out, c_in, c_noise

# 定义 ZeroSNRScaling 类,类似于 VScaling
class ZeroSNRScaling: # similar to VScaling
    # 定义一个可调用方法 __call__
    def __call__(
        self, alphas_cumprod_sqrt: torch.Tensor, **additional_model_inputs
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        # 将 alphas_cumprod_sqrt 作为 c_skip
        c_skip = alphas_cumprod_sqrt
        # 计算 c_out 的值
        c_out = - (1 - alphas_cumprod_sqrt**2) ** 0.5
        # 创建与 alphas_cumprod_sqrt 相同形状的全 1 张量作为 c_in
        c_in = torch.ones_like(alphas_cumprod_sqrt, device=alphas_cumprod_sqrt.device)
        # 复制额外输入的 'idx' 作为 c_noise
        c_noise = additional_model_inputs['idx'].clone()
        # 返回四个计算结果
        return c_skip, c_out, c_in, c_noise

.\cogview3-finetune\sat\sgm\modules\diffusionmodules\denoiser_weighting.py

# 导入 PyTorch 库
import torch


# 定义一个单位加权类
class UnitWeighting:
    # 定义可调用方法,接收参数 sigma
    def __call__(self, sigma):
        # 返回与 sigma 形状相同的全 1 张量,设备与 sigma 相同
        return torch.ones_like(sigma, device=sigma.device)


# 定义 EDM 加权类
class EDMWeighting:
    # 初始化方法,设置 sigma_data 的默认值为 0.5
    def __init__(self, sigma_data=0.5):
        # 将传入的 sigma_data 存储为实例变量
        self.sigma_data = sigma_data

    # 定义可调用方法,接收参数 sigma
    def __call__(self, sigma):
        # 根据公式计算加权值并返回
        return (sigma**2 + self.sigma_data**2) / (sigma * self.sigma_data) ** 2


# 定义 V 加权类,继承自 EDMWeighting
class VWeighting(EDMWeighting):
    # 初始化方法
    def __init__(self):
        # 调用父类初始化方法,设置 sigma_data 为 1.0
        super().__init__(sigma_data=1.0)


# 定义 Eps 加权类
class EpsWeighting:
    # 定义可调用方法,接收参数 sigma
    def __call__(self, sigma):
        # 返回 sigma 的 -2 次幂
        return sigma**-2.0

.\cogview3-finetune\sat\sgm\modules\diffusionmodules\discretizer.py

# 从 abc 模块导入抽象方法,用于定义抽象基类
from abc import abstractmethod
# 从 functools 模块导入 partial,用于创建偏函数
from functools import partial

# 导入 numpy 库,主要用于数值计算
import numpy as np
# 导入 torch 库,主要用于深度学习和张量操作
import torch

# 从自定义模块中导入 make_beta_schedule 函数,用于生成 beta 时间表
from ...modules.diffusionmodules.util import make_beta_schedule
# 从自定义模块中导入 append_zero 函数,用于处理数组
from ...util import append_zero


# 定义函数 generate_roughly_equally_spaced_steps,用于生成大致均匀间隔的步骤
def generate_roughly_equally_spaced_steps(
    num_substeps: int, max_step: int
) -> np.ndarray:
    # 生成从 max_step-1 到 0 的均匀间隔数组,包含 num_substeps 个元素,并将结果反转
    return np.linspace(max_step - 1, 0, num_substeps, endpoint=False).astype(int)[::-1]

# 定义函数 sub_generate_roughly_equally_spaced_steps,用于生成两个子步骤的均匀间隔
def sub_generate_roughly_equally_spaced_steps(
    num_substeps_1: int, num_substeps_2: int, max_step: int
) -> np.ndarray:
    # 生成第二组子步骤的均匀间隔
    substeps_2 = np.linspace(max_step - 1, 0, num_substeps_2, endpoint=False).astype(int)[::-1]
    # 生成第一组子步骤的均匀间隔
    substeps_1 = np.linspace(num_substeps_2 - 1, 0, num_substeps_1, endpoint=False).astype(int)[::-1]
    # 返回根据第一组子步骤索引获取第二组子步骤的数组
    return substeps_2[substeps_1]

# 定义离散化的抽象基类 Discretization
class Discretization:
    # 定义可调用方法,接受 n、do_append_zero、device 和 flip 参数
    def __call__(self, n, do_append_zero=True, device="cpu", flip=False):
        # 调用 get_sigmas 方法获取 sigma 值
        sigmas = self.get_sigmas(n, device=device)
        # 如果 do_append_zero 为真,向 sigmas 添加零
        sigmas = append_zero(sigmas) if do_append_zero else sigmas
        # 根据 flip 参数决定是否反转 sigmas
        return sigmas if not flip else torch.flip(sigmas, (0,))

    # 定义抽象方法 get_sigmas,必须在子类中实现
    @abstractmethod
    def get_sigmas(self, n, device):
        pass


# 定义 EDMDiscretization 类,继承自 Discretization
class EDMDiscretization(Discretization):
    # 初始化方法,设置 sigma 的最小值、最大值和 rho 值
    def __init__(self, sigma_min=0.002, sigma_max=80.0, rho=7.0):
        self.sigma_min = sigma_min
        self.sigma_max = sigma_max
        self.rho = rho

    # 实现抽象方法 get_sigmas,计算 sigma 值
    def get_sigmas(self, n, device="cpu"):
        # 生成从 0 到 1 的等间隔张量
        ramp = torch.linspace(0, 1, n, device=device)
        # 计算最小和最大 rho 的倒数
        min_inv_rho = self.sigma_min ** (1 / self.rho)
        max_inv_rho = self.sigma_max ** (1 / self.rho)
        # 根据公式计算 sigmas
        sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** self.rho
        # 返回计算得到的 sigmas
        return sigmas


# 定义 LegacyDDPMDiscretization 类,继承自 Discretization
class LegacyDDPMDiscretization(Discretization):
    # 初始化方法,设置线性开始、结束和时间步数
    def __init__(
        self,
        linear_start=0.00085,
        linear_end=0.0120,
        num_timesteps=1000,
    ):
        # 调用父类初始化方法
        super().__init__()
        self.num_timesteps = num_timesteps
        # 使用 make_beta_schedule 生成 beta 时间表
        betas = make_beta_schedule(
            "linear", num_timesteps, linear_start=linear_start, linear_end=linear_end
        )
        # 计算 alphas
        alphas = 1.0 - betas
        # 计算累积的 alphas
        self.alphas_cumprod = np.cumprod(alphas, axis=0)
        # 创建将 numpy 数组转换为 torch 张量的偏函数
        self.to_torch = partial(torch.tensor, dtype=torch.float32)

    # 实现抽象方法 get_sigmas,根据 n 计算 sigma 值
    def get_sigmas(self, n, device="cpu"):
        # 如果 n 小于总时间步数
        if n < self.num_timesteps:
            # 生成大致均匀间隔的时间步
            timesteps = generate_roughly_equally_spaced_steps(n, self.num_timesteps)
            # 获取对应的累积 alphas
            alphas_cumprod = self.alphas_cumprod[timesteps]
        # 如果 n 等于总时间步数
        elif n == self.num_timesteps:
            alphas_cumprod = self.alphas_cumprod
        # 如果 n 大于总时间步数,抛出异常
        else:
            raise ValueError

        # 创建将 numpy 数组转换为 torch 张量的偏函数,指定设备
        to_torch = partial(torch.tensor, dtype=torch.float32, device=device)
        # 计算 sigmas 值
        sigmas = to_torch((1 - alphas_cumprod) / alphas_cumprod) ** 0.5
        # 返回反转的 sigmas
        return torch.flip(sigmas, (0,)) # sigma_t: 14.4 -> 0.029

# 定义 ZeroSNRDDPMDiscretization 类,继承自 Discretization
class ZeroSNRDDPMDiscretization(Discretization):
    # 初始化方法,设置线性开始、结束、时间步数和 shift_scale
    def __init__(
        self,
        linear_start=0.00085,
        linear_end=0.0120,
        num_timesteps=1000,
        shift_scale=1.,
    # 初始化父类
        ):
            super().__init__()
            # 设置时间步数
            self.num_timesteps = num_timesteps
            # 生成线性调度的 beta 值
            betas = make_beta_schedule(
                "linear", num_timesteps, linear_start=linear_start, linear_end=linear_end
            )
            # 计算 alpha 值
            alphas = 1.0 - betas
            # 计算累积的 alpha 值
            self.alphas_cumprod = np.cumprod(alphas, axis=0)
            # 将数据转换为 torch 张量的函数
            self.to_torch = partial(torch.tensor, dtype=torch.float32)
    
            # 对累积的 alpha 值进行缩放
            self.alphas_cumprod = self.alphas_cumprod / (shift_scale + (1-shift_scale) * self.alphas_cumprod)
    
        # 获取 sigma 值的方法
        def get_sigmas(self, n, device="cpu", return_idx=False):
            # 判断请求的时间步数是否小于总时间步数
            if n < self.num_timesteps:
                # 生成等间距的时间步
                timesteps = generate_roughly_equally_spaced_steps(n, self.num_timesteps)
                # 获取对应的累积 alpha 值
                alphas_cumprod = self.alphas_cumprod[timesteps]
            # 判断请求的时间步数是否等于总时间步数
            elif n == self.num_timesteps:
                alphas_cumprod = self.alphas_cumprod
            # 如果超出范围,抛出异常
            else:
                raise ValueError
    
            # 将 alpha 值转换为 torch 张量
            to_torch = partial(torch.tensor, dtype=torch.float32, device=device)
            alphas_cumprod = to_torch(alphas_cumprod)
            # 计算累积 alpha 值的平方根
            alphas_cumprod_sqrt = alphas_cumprod.sqrt()
            # 克隆初始和最终的 alpha 值
            alphas_cumprod_sqrt_0 = alphas_cumprod_sqrt[0].clone()
            alphas_cumprod_sqrt_T = alphas_cumprod_sqrt[-1].clone()
    
            # 对平方根进行归一化处理
            alphas_cumprod_sqrt -= alphas_cumprod_sqrt_T
            alphas_cumprod_sqrt *= alphas_cumprod_sqrt_0 / (alphas_cumprod_sqrt_0 - alphas_cumprod_sqrt_T)
    
            # 根据返回标志返回结果
            if return_idx:
                return torch.flip(alphas_cumprod_sqrt, (0,)), timesteps
            else:
                # 返回反转的平方根 alpha 值
                return torch.flip(alphas_cumprod_sqrt, (0,)) # sqrt(alpha_t): 0 -> 0.99
            
        # 使对象可调用的方法
        def __call__(self, n, do_append_zero=True, device="cpu", flip=False, return_idx=False):
            # 根据返回标志调用获取 sigma 值的方法
            if return_idx:
                sigmas, idx = self.get_sigmas(n, device=device, return_idx=True)
                sigmas = append_zero(sigmas) if do_append_zero else sigmas
                # 根据 flip 标志返回结果
                return (sigmas, idx) if not flip else (torch.flip(sigmas, (0,)), torch.flip(idx, (0,)))
            else:
                # 获取 sigma 值并处理
                sigmas = self.get_sigmas(n, device=device)
                sigmas = append_zero(sigmas) if do_append_zero else sigmas
                # 根据 flip 标志返回结果
                return sigmas if not flip else torch.flip(sigmas, (0,))

.\cogview3-finetune\sat\sgm\modules\diffusionmodules\dit.py

# 从 omegaconf 导入 DictConfig 类,用于配置管理
from omegaconf import DictConfig
# 从 functools 导入 partial 函数,用于偏函数应用
from functools import partial
# 从 einops 导入 rearrange 函数,用于重排张量
from einops import rearrange
# 导入 numpy 库,用于数值计算
import numpy as np

# 导入 PyTorch 库
import torch
# 从 torch 导入 nn 模块,包含神经网络构建相关的工具
from torch import nn
# 导入 PyTorch 的分布式训练模块
import torch.distributed

# 从 sat.model.base_model 导入 BaseModel 类,作为模型基类
from sat.model.base_model import BaseModel
# 从 sat.model.mixins 导入 BaseMixin 类,用于混入模型功能
from sat.model.mixins import BaseMixin
# 从 sat.ops.layernorm 导入 LayerNorm 类,用于层归一化
from sat.ops.layernorm import LayerNorm
# 从 sat.transformer_defaults 导入默认的 hooks 和注意力函数
from sat.transformer_defaults import HOOKS_DEFAULT, attention_fn_default
# 从 sat.mpu.utils 导入用于张量分割的工具
from sat.mpu.utils import split_tensor_along_last_dim
# 从 sgm.util 导入一些工具函数
from sgm.util import (
    disabled_train,          # 禁用训练的装饰器
    instantiate_from_config, # 从配置实例化对象
)

# 从 sgm.modules.diffusionmodules.openaimodel 导入时间步类
from sgm.modules.diffusionmodules.openaimodel import Timestep
# 从 sgm.modules.diffusionmodules.util 导入卷积、线性层和时间步嵌入等工具
from sgm.modules.diffusionmodules.util import (
    conv_nd,                # 多维卷积函数
    linear,                 # 线性变换函数
    timestep_embedding,     # 时间步嵌入函数
)

# 定义调制函数,接受输入张量、偏移量和缩放因子
def modulate(x, shift, scale):
    # 根据缩放因子和偏移量对输入张量进行调制
    return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)

# 定义解patch化函数,将补丁张量恢复为图像张量
def unpatchify(x, channels, patch_size, height, width):
    # 使用 rearrange 函数将补丁张量重排为图像张量
    x = rearrange(
        x,
        "b (h w) (c p1 p2) -> b c (h p1) (w p2)",  # 指定输入和输出的张量维度
        h=height // patch_size,  # 计算行数
        w=width // patch_size,    # 计算列数
        p1=patch_size,            # 每个补丁的高度
        p2=patch_size,            # 每个补丁的宽度
    )
    # 返回解patch化后的图像张量
    return x

# 定义图像补丁嵌入混入类,继承自 BaseMixin
class ImagePatchEmbeddingMixin(BaseMixin):
    # 初始化函数,设置输入通道、隐藏层大小、补丁大小等属性
    def __init__(
            self,
            in_channels,
            hidden_size,
            patch_size,
            text_hidden_size=None,
            do_rearrange=True,
    ):
        # 调用父类初始化
        super().__init__()
        # 设置输入通道数
        self.in_channels = in_channels
        # 设置隐藏层大小
        self.hidden_size = hidden_size
        # 设置补丁大小
        self.patch_size = patch_size
        # 设置文本隐藏层大小(如果有的话)
        self.text_hidden_size = text_hidden_size
        # 设置是否重排张量的标志
        self.do_rearrange = do_rearrange

        # 初始化线性层,将补丁的通道数映射到隐藏层大小
        self.proj = nn.Linear(in_channels * patch_size ** 2, hidden_size)
        # 如果提供了文本隐藏层大小,则初始化相应的线性层
        if text_hidden_size is not None:
            self.text_proj = nn.Linear(text_hidden_size, hidden_size)

    # 定义词嵌入前向传播函数,接受输入ID、图像和编码器输出
    def word_embedding_forward(self, input_ids, images, encoder_outputs, **kwargs):
        # images: B x C x H x W,表示批量图像的形状
        # 如果需要重排图像张量
        if self.do_rearrange:
            # 使用 rearrange 函数将图像重排为补丁格式
            patches_images = rearrange(
                images, "b c (h p1) (w p2) -> b (h w) (c p1 p2)",  # 指定输入和输出的张量维度
                p1=self.patch_size,  # 每个补丁的高度
                p2=self.patch_size,  # 每个补丁的宽度
            )
        else:
            # 否则直接使用原始图像张量
            patches_images = images
        # 通过线性层对补丁图像进行映射
        emb = self.proj(patches_images)

        # 如果有文本隐藏层大小
        if self.text_hidden_size is not None:
            # 对编码器输出进行线性映射
            text_emb = self.text_proj(encoder_outputs)
            # 将文本嵌入与图像嵌入在维度1上进行连接
            emb = torch.cat([text_emb, emb], dim=1)

        # 返回最终的嵌入结果
        return emb

    # 定义重新初始化函数
    def reinit(self, parent_model=None):
        # 获取线性层的权重
        w = self.proj.weight.data
        # 使用 Xavier 均匀分布初始化权重
        nn.init.xavier_uniform_(self.proj.weight)
        # 将偏置初始化为零
        nn.init.constant_(self.proj.bias, 0)
        # 删除 transformer 的词嵌入
        del self.transformer.word_embeddings

# 定义获取 2D 正弦余弦位置嵌入的函数
def get_2d_sincos_pos_embed(embed_dim, grid_height, grid_width, cls_token=False, extra_tokens=0):
    """
    grid_size: int of the grid height and width
    return:
    pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
    """
    # 创建一个表示网格高度的数组
    grid_h = np.arange(grid_height, dtype=np.float32)
    # 创建一个表示网格宽度的数组
    grid_w = np.arange(grid_width, dtype=np.float32)
    # 生成网格坐标,宽度优先
    grid = np.meshgrid(grid_w, grid_h)  # here w goes first
    # 将网格数组堆叠到一个新的维度
    grid = np.stack(grid, axis=0)

    # 将网格重塑为 [2, 1, grid_height, grid_width] 形状
    grid = grid.reshape([2, 1, grid_height, grid_width])
    # 从网格生成二维正弦余弦位置嵌入
        pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
        # 如果需要分类标记且额外的标记数大于0,则进行处理
        if cls_token and extra_tokens > 0:
            # 在位置嵌入前添加额外的零嵌入,形成新的位置嵌入
            pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
        # 返回最终的位置嵌入
        return pos_embed
# 从网格生成二维正弦余弦位置嵌入
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    # 确保嵌入维度是偶数
    assert embed_dim % 2 == 0

    # 使用一半的维度编码网格高度
    emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)
    # 使用一半的维度编码网格宽度
    emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)

    # 将高度和宽度的嵌入合并
    emb = np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
    # 返回最终的嵌入
    return emb


# 从网格生成一维正弦余弦位置嵌入
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    embed_dim: 每个位置的输出维度
    pos: 需要编码的位置列表:大小 (M,)
    out: (M, D)
    """
    # 确保嵌入维度是偶数
    assert embed_dim % 2 == 0
    # 生成 omega 数组用于位置编码
    omega = np.arange(embed_dim // 2, dtype=np.float64)
    omega /= embed_dim / 2.0
    # 计算频率,得到 (D/2,)
    omega = 1.0 / 10000 ** omega  # (D/2,)

    # 将位置调整为一维
    pos = pos.reshape(-1)  # (M,)
    # 计算位置与频率的外积
    out = np.einsum("m,d->md", pos, omega)  # (M, D/2), outer product

    # 计算正弦和余弦嵌入
    emb_sin = np.sin(out)  # (M, D/2)
    emb_cos = np.cos(out)  # (M, D/2)

    # 将正弦和余弦嵌入合并
    emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
    # 返回最终的嵌入
    return emb


# 位置嵌入混合类
class PositionEmbeddingMixin(BaseMixin):
    def __init__(
            self,
            max_height,
            max_width,
            hidden_size,
            text_length=0,
            block_size=16,
            **kwargs,
    ):
        # 初始化父类
        super().__init__()
        # 设置最大高度
        self.max_height = max_height
        # 设置最大宽度
        self.max_width = max_width
        # 设置隐藏层大小
        self.hidden_size = hidden_size
        # 设置文本长度
        self.text_length = text_length
        # 设置块大小
        self.block_size = block_size
        # 初始化图像位置嵌入参数
        self.image_pos_embedding = nn.Parameter(
            torch.zeros(self.max_height, self.max_width, hidden_size), requires_grad=False
        )

    # 前向传播计算位置嵌入
    def position_embedding_forward(self, position_ids, target_size, **kwargs):
        ret = []
        # 遍历目标大小
        for h, w in target_size:
            # 将高度和宽度除以块大小
            h, w = h // self.block_size, w // self.block_size
            # 获取图像位置嵌入并重塑
            image_pos_embed = self.image_pos_embedding[:h, :w].reshape(h * w, -1)
            # 连接文本嵌入与图像嵌入
            pos_embed = torch.cat(
                [
                    torch.zeros(
                        (self.text_length, self.hidden_size),
                        dtype=image_pos_embed.dtype,
                        device=image_pos_embed.device,
                    ),
                    image_pos_embed,
                ],
                dim=0,
            )
            # 添加到结果列表中
            ret.append(pos_embed[None, ...])
        # 合并所有位置嵌入
        return torch.cat(ret, dim=0)

    # 重新初始化位置嵌入
    def reinit(self, parent_model=None):
        # 删除当前位置嵌入
        del self.transformer.position_embeddings
        # 获取新的二维正弦余弦位置嵌入
        pos_embed = get_2d_sincos_pos_embed(self.image_pos_embedding.shape[-1], self.max_height, self.max_width)
        # 重塑位置嵌入为二维形状
        pos_embed = pos_embed.reshape(self.max_height, self.max_width, -1)

        # 复制新的位置嵌入数据
        self.image_pos_embedding.data.copy_(torch.from_numpy(pos_embed).float())


# 最终层混合类
class FinalLayerMixin(BaseMixin):
    def __init__(
            self,
            hidden_size,
            time_embed_dim,
            patch_size,
            block_size,
            out_channels,
            elementwise_affine=False,
            eps=1e-6,
            do_unpatchify=True,
    ):
        # 调用父类构造函数进行初始化
        super().__init__()
        # 设置隐藏层大小
        self.hidden_size = hidden_size
        # 设置每个补丁的大小
        self.patch_size = patch_size
        # 设置块的大小
        self.block_size = block_size
        # 设置输出通道数
        self.out_channels = out_channels
        # 确定是否进行去补丁处理
        self.do_unpatchify = do_unpatchify

        # 初始化最终的层归一化,带有可学习的参数
        self.norm_final = nn.LayerNorm(
            hidden_size,
            elementwise_affine=elementwise_affine,
            eps=eps,
        )
        # 创建一个包含SiLU激活和线性层的序列
        self.adaln = nn.Sequential(
            nn.SiLU(),
            nn.Linear(time_embed_dim, 2 * hidden_size),
        )
        # 初始化线性层以将隐藏状态映射到输出通道
        self.linear = nn.Linear(hidden_size, out_channels * patch_size ** 2)

    def final_forward(self, logits, emb, text_length, target_size=None, **kwargs):
        # 截取logits以获取文本长度后的部分
        x = logits[:, text_length:]
        # 使用adaln模块对嵌入进行变换并拆分为偏移和缩放
        shift, scale = self.adaln(emb).chunk(2, dim=1)
        # 对x进行归一化并应用偏移和缩放
        x = modulate(self.norm_final(x), shift, scale)
        # 通过线性层获得最终输出
        x = self.linear(x)

        # 如果需要去补丁处理
        if self.do_unpatchify:
            # 从目标大小中提取高度和宽度
            target_height, target_width = target_size[0]
            # 断言目标大小必须能被块大小整除
            assert (
                    target_height % self.block_size == 0 and target_width % self.block_size == 0
            ), "target size must be divisible by block size"
            # 计算输出高度和宽度
            out_height, out_width = (
                target_height // self.block_size * self.patch_size,
                target_width // self.block_size * self.patch_size,
            )
            # 进行去补丁处理,恢复原图
            x = unpatchify(
                x, channels=self.out_channels, patch_size=self.patch_size, height=out_height, width=out_width
            )
        # 返回最终输出
        return x

    def reinit(self, parent_model=None):
        # 使用Xavier均匀分布初始化线性层权重
        nn.init.xavier_uniform_(self.linear.weight)
        # 将线性层偏置初始化为0
        nn.init.constant_(self.linear.bias, 0)
# 定义一个混合类 AdalnAttentionMixin,继承自 BaseMixin
class AdalnAttentionMixin(BaseMixin):
    # 初始化函数,接受多个参数以设置模型的各项属性
    def __init__(
            self,
            hidden_size,  # 隐藏层大小
            num_layers,  # 层数
            time_embed_dim,  # 时间嵌入维度
            qk_ln=True,  # 是否使用查询和键的层归一化
            hidden_size_head=None,  # 头部的隐藏层大小
            elementwise_affine=False,  # 是否使用逐元素仿射变换
            eps=1e-6,  # 用于层归一化的平滑项
    ):
        # 调用父类的初始化方法
        super().__init__()

        # 创建一个包含多个顺序模块的模块列表,每个模块由 SiLU 激活函数和线性层组成
        self.adaln_modules = nn.ModuleList(
            [nn.Sequential(nn.SiLU(), nn.Linear(time_embed_dim, 12 * hidden_size)) for _ in range(num_layers)]
        )

        # 记录是否使用查询和键的层归一化
        self.qk_ln = qk_ln
        # 如果使用层归一化,则为查询和键分别创建模块列表
        if qk_ln:
            # 创建用于查询的层归一化模块列表
            self.query_layernorms = nn.ModuleList(
                [
                    LayerNorm(hidden_size_head, elementwise_affine=elementwise_affine, eps=eps)
                    for _ in range(num_layers)  # 为每一层创建一个层归一化模块
                ]
            )
            # 创建用于键的层归一化模块列表
            self.key_layernorms = nn.ModuleList(
                [
                    LayerNorm(hidden_size_head, elementwise_affine=elementwise_affine, eps=eps)
                    for _ in range(num_layers)  # 为每一层创建一个层归一化模块
                ]
            )

    # 定义前向传播的方法,接收隐藏状态、掩码、文本长度等参数
    def layer_forward(
            self,
            hidden_states,  # 当前层的隐藏状态
            mask,  # 掩码,用于忽略特定位置
            text_length,  # 文本的长度
            layer_id,  # 当前层的索引
            emb,  # 嵌入表示
            *args,  # 可变参数
            **kwargs,  # 关键字参数
    # 定义一个方法的结尾,接受必要的参数
        ):
            # 获取指定层的 Transformer 层
            layer = self.transformer.layers[layer_id]
            # 获取与该层对应的自适应层归一化模块
            adaln_module = self.adaln_modules[layer_id]
    
            # 从自适应层归一化模块中处理输入并将结果分块为 12 个部分
            (
                shift_msa_img,
                scale_msa_img,
                gate_msa_img,
                shift_mlp_img,
                scale_mlp_img,
                gate_mlp_img,
                shift_msa_txt,
                scale_msa_txt,
                gate_msa_txt,
                shift_mlp_txt,
                scale_mlp_txt,
                gate_mlp_txt,
            ) = adaln_module(emb).chunk(12, dim=1)
            # 扩展门控张量的维度,以便后续处理
            gate_msa_img, gate_mlp_img, gate_msa_txt, gate_mlp_txt = (
                gate_msa_img.unsqueeze(1),
                gate_mlp_img.unsqueeze(1),
                gate_msa_txt.unsqueeze(1),
                gate_mlp_txt.unsqueeze(1),
            )
    
            # 对输入的隐藏状态进行层归一化处理
            attention_input = layer.input_layernorm(hidden_states)
    
            # 对文本输入进行调制,应用相应的偏移和缩放
            text_attention_input = modulate(attention_input[:, :text_length], shift_msa_txt, scale_msa_txt)
            # 对图像输入进行调制,应用相应的偏移和缩放
            image_attention_input = modulate(attention_input[:, text_length:], shift_msa_img, scale_msa_img)
            # 将文本和图像的注意力输入合并
            attention_input = torch.cat((text_attention_input, image_attention_input), dim=1)
    
            # 计算注意力输出,应用遮罩和层 ID
            attention_output = layer.attention(attention_input, mask, layer_id=layer_id, **kwargs)
            # 如果层归一化顺序是 "sandwich",则应用第三次层归一化
            if self.transformer.layernorm_order == "sandwich":
                attention_output = layer.third_layernorm(attention_output)
    
            # 将隐藏状态分为文本和图像部分
            text_hidden_states, image_hidden_states = hidden_states[:, :text_length], hidden_states[:, text_length:]
            # 将注意力输出分为文本和图像部分
            text_attention_output, image_attention_output = (
                attention_output[:, :text_length],
                attention_output[:, text_length:],
            )
            # 更新文本隐藏状态,加上文本注意力输出的加权
            text_hidden_states = text_hidden_states + gate_msa_txt * text_attention_output
            # 更新图像隐藏状态,加上图像注意力输出的加权
            image_hidden_states = image_hidden_states + gate_msa_img * image_attention_output
            # 合并更新后的隐藏状态
            hidden_states = torch.cat((text_hidden_states, image_hidden_states), dim=1)
    
            # 对合并后的隐藏状态进行后注意力层归一化
            mlp_input = layer.post_attention_layernorm(hidden_states)
    
            # 对文本输入进行调制,应用相应的偏移和缩放
            text_mlp_input = modulate(mlp_input[:, :text_length], shift_mlp_txt, scale_mlp_txt)
            # 对图像输入进行调制,应用相应的偏移和缩放
            image_mlp_input = modulate(mlp_input[:, text_length:], shift_mlp_img, scale_mlp_img)
            # 将文本和图像的 MLP 输入合并
            mlp_input = torch.cat((text_mlp_input, image_mlp_input), dim=1)
    
            # 计算 MLP 输出,应用层 ID
            mlp_output = layer.mlp(mlp_input, layer_id=layer_id, **kwargs)
            # 如果层归一化顺序是 "sandwich",则应用第四次层归一化
            if self.transformer.layernorm_order == "sandwich":
                mlp_output = layer.fourth_layernorm(mlp_output)
    
            # 将隐藏状态分为文本和图像部分
            text_hidden_states, image_hidden_states = hidden_states[:, :text_length], hidden_states[:, text_length:]
            # 将 MLP 输出分为文本和图像部分
            text_mlp_output, image_mlp_output = mlp_output[:, :text_length], mlp_output[:, text_length:]
            # 更新文本隐藏状态,加上文本 MLP 输出的加权
            text_hidden_states = text_hidden_states + gate_mlp_txt * text_mlp_output
            # 更新图像隐藏状态,加上图像 MLP 输出的加权
            image_hidden_states = image_hidden_states + gate_mlp_img * image_mlp_output
            # 合并更新后的隐藏状态
            hidden_states = torch.cat((text_hidden_states, image_hidden_states), dim=1)
    
            # 返回最终的隐藏状态
            return hidden_states
    # 定义注意力前向传播函数,接收隐藏状态、掩码、层ID及其他参数
    def attention_forward(self, hidden_states, mask, layer_id, **kwargs):
        # 获取指定层的注意力模块
        attention = self.transformer.layers[layer_id].attention

        # 默认的注意力计算函数
        attention_fn = attention_fn_default
        # 如果注意力模块有自定义的注意力函数,则使用它
        if "attention_fn" in attention.hooks:
            attention_fn = attention.hooks["attention_fn"]

        # 通过隐藏状态计算查询、键、值
        qkv = attention.query_key_value(hidden_states)
        # 将查询、键、值沿最后一个维度分离
        mixed_query_layer, mixed_key_layer, mixed_value_layer = split_tensor_along_last_dim(qkv, 3)

        # 根据训练状态选择是否应用 dropout
        dropout_fn = attention.attention_dropout if self.training else None

        # 转置查询、键、值以便于后续计算
        query_layer = attention._transpose_for_scores(mixed_query_layer)
        key_layer = attention._transpose_for_scores(mixed_key_layer)
        value_layer = attention._transpose_for_scores(mixed_value_layer)

        # 如果使用层归一化,应用于查询和键
        if self.qk_ln:
            query_layernorm = self.query_layernorms[layer_id]
            key_layernorm = self.key_layernorms[layer_id]
            query_layer = query_layernorm(query_layer)
            key_layer = key_layernorm(key_layer)

        # 计算上下文层,使用指定的注意力函数
        context_layer = attention_fn(query_layer, key_layer, value_layer, mask, dropout_fn, **kwargs)
        # 调整上下文层的维度顺序
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        # 创建新的上下文层形状,适配后续计算
        new_context_layer_shape = context_layer.size()[:-2] + (attention.hidden_size_per_partition,)
        # 重新调整上下文层的形状
        context_layer = context_layer.view(*new_context_layer_shape)

        # 通过全连接层计算输出
        output = attention.dense(context_layer)
        # 如果处于训练状态,应用输出的 dropout
        if self.training:
            output = attention.output_dropout(output)

        # 返回最终输出
        return output

    # 定义多层感知机的前向传播函数
    def mlp_forward(self, hidden_states, layer_id, **kwargs):
        # 获取指定层的多层感知机模块
        mlp = self.transformer.layers[layer_id].mlp

        # 通过全连接层将隐藏状态映射到更高维度
        intermediate_parallel = mlp.dense_h_to_4h(hidden_states)
        # 应用激活函数
        intermediate_parallel = mlp.activation_func(intermediate_parallel)
        # 将高维结果映射回原维度
        output = mlp.dense_4h_to_h(intermediate_parallel)

        # 如果处于训练状态,应用 dropout
        if self.training:
            output = mlp.dropout(output)

        # 返回最终输出
        return output

    # 定义重新初始化函数,接受可选的父模型参数
    def reinit(self, parent_model=None):
        # 遍历自适应层模块
        for layer in self.adaln_modules:
            # 将最后一层的权重初始化为 0
            nn.init.constant_(layer[-1].weight, 0)
            # 将最后一层的偏置初始化为 0
            nn.init.constant_(layer[-1].bias, 0)
# 定义一个字符串到数据类型的映射
str_to_dtype = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}


# 定义一个扩散变换器类,继承自基础模型
class DiffusionTransformer(BaseModel):
    # 初始化方法,接受多个参数以配置模型
    def __init__(
            self,
            in_channels,  # 输入通道数
            out_channels,  # 输出通道数
            hidden_size,  # 隐藏层大小
            patch_size,  # 图像块大小
            num_layers,  # 层数
            num_attention_heads,  # 注意力头数
            text_length,  # 文本长度
            time_embed_dim=None,  # 时间嵌入维度,默认为 None
            num_classes=None,  # 类别数量,默认为 None
            adm_in_channels=None,  # 自适应输入通道,默认为 None
            modules={},  # 额外模块,默认为空字典
            dtype="fp32",  # 数据类型,默认为 fp32
            layernorm_order="pre",  # 层归一化顺序,默认为 "pre"
            elementwise_affine=False,  # 是否启用逐元素仿射,默认为 False
            parallel_output=True,  # 是否并行输出,默认为 True
            block_size=16,  # 块大小,默认为 16
            **kwargs,  # 其他关键字参数
    ):
        # 初始化基类
        super().__init__(**kwargs)

    # 前向传播方法,定义模型的前向计算
    def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
        # 将输入张量 x 转换为指定的数据类型
        x = x.to(self.dtype)
        # 获取时间步的嵌入表示
        t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False, dtype=self.dtype)
        # 对时间嵌入进行处理
        emb = self.time_embed(t_emb)

        # 确保 y 和 x 的批次大小一致
        assert y.shape[0] == x.shape[0]
        # 将标签嵌入与时间嵌入相加
        emb = emb + self.label_emb(y)
        # 创建输入 ID、位置 ID 和注意力掩码,均初始化为 1
        input_ids = position_ids = attention_mask = torch.ones((1, 1)).to(x.dtype)

        # 调用基类的前向方法,传入多个参数以计算输出
        output = super().forward(
            images=x,  # 输入图像
            emb=emb,  # 嵌入表示
            encoder_outputs=context,  # 编码器输出
            text_length=self.text_length,  # 文本长度
            input_ids=input_ids,  # 输入 ID
            position_ids=position_ids,  # 位置 ID
            attention_mask=attention_mask,  # 注意力掩码
            **kwargs,  # 其他关键字参数
        )[0]  # 获取输出的第一个元素

        # 返回模型的输出
        return output
posted @ 2024-10-23 09:19  绝不原创的飞龙  阅读(15)  评论(0编辑  收藏  举报