LLM大模型: diffusion transformer Dit原理和核心代码

  现阶段,主流文生图的思路就是DDPM了:先随机生成N~(0,1)的噪声图,然后逐步denoise,迭代1000次左右得到text指定的图;其中最核心的莫过于denoise时生成的noise图片:每次需要根据输入时间t、文本text和noise latent生成合适的noise图片。之前介绍了unet的方式生成noise图片,存在一些缺陷:

  • unet使用cnn提取图像特征,提取的是局部特征,缺乏全局信息;比如要生成一张全景城市图片,其中左边的天空非常蓝,右边有一座高楼。如果想生成一张完整的图像,需要模型了解图片的左边、右边,前面后面之间的关系,但UNet由于卷积局限,可能无法有效捕捉这种大范围的关联【cnn要想捕获全局信息,需要比较大的感受野,所以要堆叠几十层才行;而vision transformer的attention机制,在第一层就能让position 0 的token知道其他patch的信息(实际训练时大约5层效果接近cnn 30层);nlp核心之一就是当前token的context,cv核心之一就是感受野】。
  • unet没有self-attention机制,无法描述局部特征之间的远近关系;比如模型需要辨认不同建筑的风格、街道的布局,甚至物体之间的语义关系(如汽车在道路上行驶,行人在走道上行走),如果不依赖大范围的上下文patch和图像结构信息,UNet难以捕捉patch之间的关系,无法生成高度复杂的细节
  • 对于高分辨率的图片:unet无法并行,效率低;cnn的pooling操作降低了分辨率,可能导致图片细节丢失;

  large language model最初是在nlp领域火了,最核心的就是attention机制啦通过attention找到token之间的距离,根据距离远近确定V值,由此提取token正确的特征,得到靠谱的embedding representation!这个attention机制能不能用在图像上了?

  1、vison transformer,简称vit;如下图所示:

  • 原始图片被切分成N个小块,比如28*28被切成7个4*4的patch;每个path拉直成1*16的向量,也就是flattened。这里的patch,类似于nlp任务的token
  • path乘以矩阵做linear projection
  • 每个path加上position embedding后,进入标准的transformer的encoder,通过attention机制找到图片内部元素之间的远近关系
  • 进入transformer encoder之前,有个0位置的position,这个地方叫做extra learnable [class] embedding,也就是0号位置旁边的星号embedding!由于从1号position开始的patch 包含的都是图片的局部信息,所以这里需要0号position的embedding表示整个图片的全局信息,类似于nlp任务的CLS token【0号位置的token通过attention机制整合了其他所有patch的信息,后续做下游任务时直接用0号token的emnedding即可,不需要再把其他token的embedding相加了】!推理时,每张图片都会使用相同的、经过训练后的 class token 作为初始输入。这是因为 class token 是一个可学习的嵌入向量,经过大量训练后,它能够学会如何总结和提取整张图片的信息。在推理阶段,虽然初始的 class token 是一样的,但在 Transformer 经过多层自注意力机制处理后,每张图片的 class token 会根据该图片的具体内容更新,从而反映出不同图片的独特特征
  • 进过encoder后就得到了整个图片的embedding,然后通过MLP就可以做分类任务了!
  • 当然,transformer也有缺陷:需要大量的训练数据!
  • conv本质也是attention,卷积核kernel就是K和Q向量!

  

   至此,不论是nlp,还是cv图片,都可以用transformer架构统一了,换句话说,都可以通过attention机制取得不错的结果了!

   2、图片的embedding representation有了, 下一个核心的任务就是生成noise图片啦!之前的unet既然有那么多缺陷,怎么用transformer架构解决了?https://arxiv.org/pdf/2212.09748 提供了思路,可以看成是transformer的encoder,网络架构如下:

  

  • noise latent,也就是有noise的中间过程图片,需要基于这些图片生成noise,具体做法就是patchify:使用上面第1步骤的Vit方法,把noise latent转成embedding representation;注意:因为condition要精准控制image的细节,所以每个patch都会转成embedding,Dit中称之为input tokens;所有的patch都会被拉直,也就是flatten;image做token化!
  • 因为需要按照用户的prompt、多次迭代生成image,所以prompt和time信息也是必须的(这两合计称为conditioning),这些信息通过常规的linear projection或MLP转成embedding即可,接下来的核心问题就是,怎么通过conditioning去影响noise的生成?换个说法:怎么根据condition去生成合适的noise?最核心的处理就在Dit block了!
  • 原论文中,Dit block的网络结构有好几种,中间那种就是熟知的cross attention来融合patch和condition,但被设置成了灰色。而左边的adaLN-Zero颜色鲜亮,原论文大概率是推荐这种方式的,那么adaLN-Zero有啥过人之处了?
    • 先看看attention机制:不管是self还是cross attention,都要先生成Q K V三个矩阵,矩阵大小是 seq length * dim,数据量和计算量都不小;而adaLN-Zero机制是把condition通过MLP生成alpha、beta、gamma参数,然后对image的token做scala和shift,完全不涉及Q K V这种矩阵,数据量和计算量小了很多
    • AdaLN-Zero 的全称是 adaptive layer normalization with zero initialization,这里有两个要点:
      • 什么是adaptive自适应?通常指模型能够根据input或condition动态调整自身参数,比如Dit这里的scala和shift,就是根据condition产生的alpha、beta、gamma来做,借此通过condition来控制和影响noise的生成
      • 什么又是zero initialization?这里指的是alpha、beta、gamma初始值是0,让网络恒等映射(输入等于输出),避免初期出现梯度爆炸或弥散,这个思路类似resnet
    •  Pointwise Feedforward 在 Transformer 模型中的作用是对每个输入 token 独立进行non-linear projection。它由两个线性变换和一个激活函数(如 ReLU)组成,作用在每个 token 上,而不考虑tokens之间的关系,因此称为“pointwise”;我个人猜测应该是前面已经做过了attention,tokens之间的关系已经挖掘,所以这里没必要了。这里做mlp,核心还是做特征组合,提取non-linear  feather,为每个token添加更多的特征,增加复杂的特征表达能力!公式为:

  代码层面具体的实现参考:https://github.com/facebookresearch/DiT 的models.py文件:

def modulate(x, shift, scale):#先scala缩放,再shift平移
    return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)

  核心就是这个DiTBlock了:

class DiTBlock(nn.Module):
    """
    A DiT block with adaptive layer norm zero (adaLN-Zero) conditioning.
    """
    def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, **block_kwargs):
        super().__init__()
        self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.attn = Attention(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs)
        self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        mlp_hidden_dim = int(hidden_size * mlp_ratio)
        approx_gelu = lambda: nn.GELU(approximate="tanh")
        self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0)
        self.adaLN_modulation = nn.Sequential( # condition做MLP转换
            nn.SiLU(),
            nn.Linear(hidden_size, 6 * hidden_size, bias=True)
        )

    def forward(self, x, c):
        # 通过condition生成6个scala和shift参数
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1)
        # 先layer norm, 再scala和shift,最后multihead attention;然后做resnet连接
        x = x + gate_msa.unsqueeze(1) * self.attn(modulate(self.norm1(x), shift_msa, scale_msa))
        #先layer norm, 再scala和shift,最后pointwise feedforward;然后做resnet连接
        x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
        return x

  最后总结一下各种generative model的核心思路:

  

 

总结:

1、https://github.com/jingyaogong/minimind-v    大语言模型(LLM)名字虽然带有语言二字,但它们其实与语言关系不大,这只是历史问题,更确切的名字应该是自回归 Transformer 或者其他。 LLM 更多是一种统计建模的通用技术,它们主要通过自回归 Transformer 来模拟 token 流,而这些 token 可以代表文本、图片、音频、动作选择、甚至是分子等任何东西。 因此,只要能将问题转化为模拟一系列离散 token 的流程,理论上都可以应用 LLM 来解决。 实际上,随着大型语言模型技术栈的日益成熟,我们可能会看到越来越多的问题被纳入这种建模范式。也就是说,问题固定在使用 LLM 进行『下一个 token 的预测』,只是每个领域中 token 的用途和含义有所不同

  文本、视频、语音、动作等在人类看来属于「多模态」信号,但所谓的「模态」其实只是人类在信息存储方式上的一种分类概念。 就像.txt.png文件,虽然在视觉呈现和高级表现形式上有所不同,但它们本质上并没有根本区别。 之所以出现「多模态」这个概念,仅仅是因为人类在不同的感知层面上对这些信号的分类需求。 然而,对于机器来说,无论信号来自何种「模态」,最终它们都只是以一串二进制的「单模态」数字序列来呈现。 机器并不会区分这些信号的模态来源,而只是处理和分析这些序列背后所承载的信息内容。

  以image为例,也可以转成NLP一样的token,然后和NLP的token一起进入transformer做处理,vit的核心代码如下:

class ViT(nn.Module):
    def __init__(self,emb_size=16):
        super().__init__()
        self.patch_size=4
        self.patch_count=28//self.patch_size # 7
        # 图片转patch; 这里的out_channels后续类似每个token的dimension,可以一步到位设置大一点,就不需要下面的patch_emb了
        self.conv=nn.Conv2d(in_channels=1,out_channels=self.patch_size**2,kernel_size=self.patch_size,padding=0,stride=self.patch_size)
        self.patch_emb=nn.Linear(in_features=self.patch_size**2,out_features=emb_size)    # patch做emb,让image的格式和nlp的token样本格式保持一直
        self.cls_token=nn.Parameter(torch.rand(1,1,emb_size))   # 分类头输入
        self.pos_emb=nn.Parameter(torch.rand(1,self.patch_count**2+1,emb_size))   # position位置向量 (1,seq_len,emb_size)
        self.tranformer_enc=nn.TransformerEncoder(nn.TransformerEncoderLayer(d_model=emb_size,nhead=2,batch_first=True),num_layers=3)   # transformer编码器
        self.cls_linear=nn.Linear(in_features=emb_size,out_features=10) # 手写数字10分类
        
    def forward(self,x): # (batch_size,channel=1,width=28,height=28)
        x=self.conv(x) # (batch_size,channel=16,width=7,height=7) # 通过卷积完成patch; 这里使用卷积的矩阵对原始输入图像做转换适配下游任务
        
        x=x.view(x.size(0),x.size(1),self.patch_count**2)   # (batch_size,channel=16,seq_len=49)
        # 这一步的数据格式其实已经和NLP接近了: batch_size; seq_len相当于每条训练样本token的数量; channel相当于embedding dimension
        # 如果觉得channel(embedding dimension)小了,可以在conv这一步直接把channel(embedding dimension)一步到位,就不需要下面的patch_emb了
        x=x.permute(0,2,1)  # (batch_size,seq_len=49,channel=16)
        # 这里的数据格式就和NLP统一了:batch_size; seq_len相当于每条训练样本token的数量,"pixel"相当于nlp的token; emb_size就是每个token embedding的维度
        x=self.patch_emb(x)   # (batch_size,seq_len=49,emb_size)
        
        cls_token=self.cls_token.expand(x.size(0),1,x.size(2))  # (batch_size,1,emb_size)
        x=torch.cat((cls_token,x),dim=1)   # add [cls] token
        x=self.pos_emb+x
        
        y=self.tranformer_enc(x) # 不涉及padding,所以不需要mask
        return self.cls_linear(y[:,0,:])   # 对[CLS] token输出做分类

 

 

参考:

1、https://www.bilibili.com/video/BV13K421h79z/?spm_id_from=333.337.search-card.all.click&vd_source=241a5bcb1c13e6828e519dd1f78f35b2  【Sora重要技术】复现DiT(Diffusion Transformer)模型

2、https://www.bilibili.com/video/BV12J4m1379T/?spm_id_from=333.337.search-card.all.click&vd_source=241a5bcb1c13e6828e519dd1f78f35b2  14步手搓sora!Diffusion Transformer, DiT工作原理

3、https://www.bilibili.com/video/BV15C411Y7cv/?spm_id_from=333.337.search-card.all.click&vd_source=241a5bcb1c13e6828e519dd1f78f35b2   Sora底层技术架构:Diffusion Transformer的论文、项目和源码

5、https://arxiv.org/pdf/2212.09748    Scalable Diffusion Models with Transformers
6、https://github.com/owenliang/mnist-vit  这个项目不错,我pc的GPU都能运行

 

  

 

posted @ 2024-10-01 20:01  第七子007  阅读(185)  评论(0编辑  收藏  举报