LLM大模型: Denoising Diffusion Probabilistic Models 原理解析与核心代码
根据文本生成图片是AI的核心应用之一,2020年后主流的生成方式都是基于Denoising Diffusion Probabilistic Models原理的,逐渐替代了之前使用GAN的方式生成图片!那么DDPM为啥能取代GAN了?其优势在哪?或者说GAN的劣势在哪?
1、CLIP模型都知道吧? text和image都通过各自的encoder转成embedding,然后两个embedding计算cosin距离,距离近的就配对,距离远的就不是匹配的。既然text都能通过encoder得到embedding了(相当于把text去掉冗余、保留核心信息),为啥不直接使用这个embedding训练decoder来生成图片了?
假设有一个text:"A cat sitting on a mat"。通过text encoder(如BERT或CLIP),我们可以将这个文本编码为一个d维向量(例如,d = 512)。在简单的架构下,我希望用这个512维的embedding直接通过解码器生成图片,这意味着:
- 输入:512维embedding
- 输出:比如生成一张分辨率为256x256的图片(像素数为65536),每个像素可能有RGB三个通道,也就是65536x3 = 196608个值。
直接从512维的embedding映射到一个包含196608个浮点值的图片,面临以下问题:
-
复杂的映射关系:图片的空间结构非常复杂,直接从512维的低维空间生成196608维的高维图片,要求模型同时学习全局和局部的图像信息。但模型的神经元是有限的,信息承载能力有限,生成图片容易缺乏多样性或质量不高。
具体的数值问题:即使decoder只是一个简单的线性网络,它的输出维度为196608。为了保持图像生成的细节和准确性,需要一个非常大的权重矩阵,维度大概是512 x 196608,直接计算和优化非常困难。如果是多层DNN,再加上激活函数,计算量就更大了!
-
缺乏生成过程的控制:直接从embedding生成图像无法逐步生成细节,难以在生成过程中进行调整和优化,导致生成的图像质量可能较差(生成的图片模糊、缺乏细节或结构失真)。
为了便于理解,这里做个对比:所有的有监督模型做back proporgation时,都会根据loss计算gradient,然后更新参数W。这里参数W的计算就有讲究了:
梯度在更新的时候,有个学习率learning rate 的超参数,一般设置为10e-3或10e-4,那么问题来了:设置learning rate后,每次更新权重的数值就小了很多,为啥不去掉learning rate直接一步到位更新权重了?
为了尽可能让loss小,参数更新尽量平滑,lr就是控制参数更新的关键参数;lr过大会导致w来回震荡,收敛不稳定,甚至可能找不到最低点,所以要通过梯度下降这种“小步快跑”的方式逐渐逼近最低点!这里的情况和text生成image类似:如果直接用text的embedding一步到位生成image,可能导致image的细节无法控制,进而影响image的质量,所以要想办法逐步生成高质量的图片!
2、(1)接下来的问题就是怎么逐步生成高质量的图片?又回到梯度跟新的这个类比:整个网络参数W一般情况下都是随机初始化的,随机值~N(0,1),然后根据loss逐步update网络参数,直到达到迭代停止的条件,得到最终的参数W;那么图片生成是不是也能借鉴这种“从初始化的随机值一步一步迭代得到最终值”的思路了?我先随机生成一张图片(当然图片的像素值都是随机的),既然是随机的,人肯定看不懂这种图片,然后通过某些方式一步一步把随机生成的图片变成高质量的图片行不?整个流程如下图所示:
刚开始随机生成一张全是noise的图片,然后通过某些”精雕细琢“的算法一点一点地把text要求的图像呈现出来!整个流程要持续几百甚至上千步,有点类似于雕刻:从整块不成型的大石头开始,去掉不需要的部分,剩下的就是成型后所需的形状了!
(2)怎么一步一步从随机noise中去掉多余的noise,得到想要的图片了?既然要去掉多余的noise,肯定要先精准地找到noise,否则怎么知道去掉啥了?所以整个生成图片的核心问题就转换成了:怎么精准找到noise!好比雕刻的时候怎么精准地找到多余的部分!
找noise这种事肯定是要借助神经网络的,这就涉及到标签了!对于只有样本数据集的情况而言,生成训练样本的方式只能如下了:
基于Gaussian distribution生成noise图片,和原图片叠加,生成第一张带有noise的图片,然后继续生成第二章noise,继续叠加,直到step达到指定步数为止,这就完整生成了训练样本!随机生成的noise图片就是需要神经网络预测的noise!这就是所谓的DDPM forward noise的过程!
inference的时候直接用随机生成的noise图片一步一步减去noise,最终得到了原始的图片!这就是DDPM 反向denoise的过程!
(3)具体工程实现时:因为noise都是Gaussian distribution,对于不同程度的加噪,可以直接设置不同的权重。比如第一次加噪,noise的权重小点;对于最后一次加噪,noise的权重非常大,这个是可以通过超参数控制的,就是下面的alpha bar!
用正式的数学公式总结一下就是:
一句话总结:从X0加噪到Xt,其实一步就够了,但是要注意控制好噪音的权重!
(4)对于去噪denoise,把上面加噪的过程反过来就行,所以核心就是每步都要求正确预测noise,然后用上一步的图片减去noise图片
3、既然整个image生成过程的核心是准确预测每一步的noise,那么预测的目标自然就是每个步骤的noise咯!预测需要通过DNN来完成,已下图为例:预测step 2的noise,那么输入就是step 2加噪的图片、step 2数字、text,输出就是预测的step 2的noise,然后和真实的noise比对,产生的loss用于更新predictor的参数!
具体到表达式就是这样的啦:
原论文的整个trainning是酱紫的:
4、 代码参考:https://github.com/owenliang/pytorch-diffusion
(1)diffusion加噪:
import torch from config import * from dataset import train_dataset,tensor_to_pil import matplotlib.pyplot as plt # 前向diffusion计算参数 betas=torch.linspace(0.0001,0.02,T) # (T,) alphas=1-betas # (T,) alphas_cumprod=torch.cumprod(alphas,dim=-1) # alpha_t累乘 (T,) [a1,a2,a3,....] -> [a1,a1*a2,a1*a2*a3,.....] alphas_cumprod_prev=torch.cat((torch.tensor([1.0]),alphas_cumprod[:-1]),dim=-1) # alpha_t-1累乘 (T,), [1,a1,a1*a2,a1*a2*a3,.....] variance=(1-alphas)*(1-alphas_cumprod_prev)/(1-alphas_cumprod) # denoise用的方差 (T,) # 执行前向加噪 def forward_diffusion(batch_x,batch_t): # batch_x: (batch,channel,width,height), batch_t: (batch_size,) batch_noise_t=torch.randn_like(batch_x) # 为每张图片生成第t步的高斯噪音 (batch,channel,width,height) batch_alphas_cumprod=alphas_cumprod.to(DEVICE)[batch_t].view(batch_x.size(0),1,1,1) batch_x_t=torch.sqrt(batch_alphas_cumprod)*batch_x+torch.sqrt(1-batch_alphas_cumprod)*batch_noise_t # 基于公式直接生成第t步加噪后图片 return batch_x_t,batch_noise_t if __name__=='__main__': batch_x=torch.stack((train_dataset[0][0],train_dataset[1][0]),dim=0).to(DEVICE) # 2个图片拼batch, (2,1,48,48) # 加噪前的样子 plt.figure(figsize=(10,10)) plt.subplot(1,2,1) plt.imshow(tensor_to_pil(batch_x[0])) plt.subplot(1,2,2) plt.imshow(tensor_to_pil(batch_x[1])) plt.show() batch_x=batch_x*2-1 # [0,1]像素值调整到[-1,1]之间,以便与高斯噪音值范围匹配 batch_t=torch.randint(0,T,size=(batch_x.size(0),)).to(DEVICE) # 每张图片随机生成diffusion步数 # batch_t=torch.tensor([5,100],dtype=torch.long) print('batch_t:',batch_t) batch_x_t,batch_noise_t=forward_diffusion(batch_x,batch_t) print('batch_x_t:',batch_x_t.size()) print('batch_noise_t:',batch_noise_t.size()) # 加噪后的样子 plt.figure(figsize=(10,10)) plt.subplot(1,2,1) plt.imshow(tensor_to_pil((batch_x_t[0]+1)/2)) plt.subplot(1,2,2) plt.imshow(tensor_to_pil((batch_x_t[1]+1)/2)) plt.show()
加噪前:
加噪后:时刻T是随机生成的
(2)transformer有position embedding,主要是区分token的位置。llama系列采用的是rotary embedding,这里借鉴类似的思路,对T时刻的做embedding!
import torch from torch import nn import math from config import * class TimePositionEmbedding(nn.Module): def __init__(self,emb_size): super().__init__() self.half_emb_size=emb_size//2 # arange:[e^(0*-1*math.log(1000)/3),e^(1*-1*math.1og(1000)/3),e^(2*-1*math,1og(1000)/3),e^(3*-1*math.log(1000)/3)] half_emb=torch.exp(torch.arange(self.half_emb_size)*(-1*math.log(10000)/(self.half_emb_size-1))) self.register_buffer('half_emb',half_emb) #固化参数,不求梯度 def forward(self,t): #[631,65] t=t.view(t.size(0),1) #形状保持一致 #[[631],[65]] # [fe^(e*-1*math.log(1000)/3),e^(1*-1*math.10g(1000)/3),e^(2*-1*math.log(1000)/3),e^(3*-1*math.1og(1000)/3)],...] half_emb=self.half_emb.unsqueeze(0).expand(t.size(0),self.half_emb_size) half_emb_t=half_emb*t embs_t = torch.cat((half_emb_t.sin(), half_emb_t.cos()), dim=-1) return embs_t if __name__=='__main__': time_pos_emb=TimePositionEmbedding(8).to(DEVICE) t=torch.randint(0,T,(2,)).to(DEVICE) # 随机2个图片的time时刻 embs_t=time_pos_emb(t) print(embs_t)
位置编码最核心的目的就是让两个不同t的embedding不一样,利于区分,比如下面的这种就是ok的:41和353两个位置的embedding就不同!
tensor([ 41, 353]) tensor([[-0.1586, 0.9453, 0.0882, 0.0041, -0.9873, -0.3262, 0.9961, 1.0000], [ 0.9093, -0.6263, 0.6893, 0.0353, 0.4161, -0.7796, 0.7245, 0.9994]])
(3)需要使用unet,根据加噪的图片和T时刻生成noise图片
代码如下:想定义一个conv block
from torch import nn from cross_attn import CrossAttention class ConvBlock(nn.Module): def __init__(self,in_channel,out_channel,time_emb_size,qsize,vsize,fsize,cls_emb_size): super().__init__() self.seq1 = nn.Sequential( nn.Conv2d(in_channel,out_channel,kernel_size=3,stride=1,padding=1), # 改通道数,不改大小 nn.BatchNorm2d(out_channel), nn.ReLU(), ) self.time_emb_linear=nn.Linear(time_emb_size,out_channel) # Time时刻emb转成channel宽,加到每个像素点上 self.relu=nn.ReLU() self.seq2=nn.Sequential( nn.Conv2d(out_channel,out_channel,kernel_size=3,stride=1,padding=1), # 不改通道数,不改大小 nn.BatchNorm2d(out_channel), nn.ReLU(), ) # 像素做Query,计算对分类ID的注意力,实现分类信息融入图像,不改变图像形状和通道数 self.crossattn=CrossAttention(channel=out_channel,qsize=qsize,vsize=vsize,fsize=fsize,cls_emb_size=cls_emb_size) def forward(self,x,t_emb,cls_emb): # t_emb: (batch_size,time_emb_size) x=self.seq1(x) # 改通道数,不改大小 t_emb=self.relu(self.time_emb_linear(t_emb)).view(x.size(0),x.size(1),1,1) # t_emb: (batch_size,out_channel,1,1) output=self.seq2(x+t_emb) # 不改通道数,不改大小 return self.crossattn(output,cls_emb) # 图像和引导向量做attention
unet预测noise:
from torch import nn from dataset import train_dataset from config import * from diffusion import forward_diffusion from time_position_emb import TimePositionEmbedding from conv_block import ConvBlock class UNet(nn.Module): def __init__(self,img_channel,channels=[64, 128, 256, 512, 1024],time_emb_size=256,qsize=16,vsize=16,fsize=32,cls_emb_size=32): super().__init__() channels=[img_channel]+channels # time转embedding self.time_emb=nn.Sequential( TimePositionEmbedding(time_emb_size), nn.Linear(time_emb_size,time_emb_size), nn.ReLU(), ) # 引导词cls转embedding self.cls_emb=nn.Embedding(10,cls_emb_size) # 每个encoder conv block增加一倍通道数 self.enc_convs=nn.ModuleList() for i in range(len(channels)-1): self.enc_convs.append(ConvBlock(channels[i],channels[i+1],time_emb_size,qsize,vsize,fsize,cls_emb_size)) # 每个encoder conv后马上缩小一倍图像尺寸,最后一个conv后不缩小 self.maxpools=nn.ModuleList() for i in range(len(channels)-2): self.maxpools.append(nn.MaxPool2d(kernel_size=2,stride=2,padding=0)) # 每个decoder conv前放大一倍图像尺寸,缩小一倍通道数 self.deconvs=nn.ModuleList() for i in range(len(channels)-2): self.deconvs.append(nn.ConvTranspose2d(channels[-i-1],channels[-i-2],kernel_size=2,stride=2)) # 每个decoder conv block减少一倍通道数 self.dec_convs=nn.ModuleList() for i in range(len(channels)-2): self.dec_convs.append(ConvBlock(channels[-i-1],channels[-i-2],time_emb_size,qsize,vsize,fsize,cls_emb_size)) # 残差结构 # 还原通道数,尺寸不变 self.output=nn.Conv2d(channels[1],img_channel,kernel_size=1,stride=1,padding=0) def forward(self,x,t,cls): # cls是引导词(图片分类ID) # time做embedding t_emb=self.time_emb(t) # cls做embedding cls_emb=self.cls_emb(cls) # encoder阶段 residual=[] for i,conv in enumerate(self.enc_convs): x=conv(x,t_emb,cls_emb) if i!=len(self.enc_convs)-1: residual.append(x) x=self.maxpools[i](x) # decoder阶段 for i,deconv in enumerate(self.deconvs): x=deconv(x) residual_x=residual.pop(-1) x=self.dec_convs[i](torch.cat((residual_x,x),dim=1),t_emb,cls_emb) # 残差用于纵深channel维 return self.output(x) # 还原通道数 if __name__=='__main__': batch_x=torch.stack((train_dataset[0][0],train_dataset[1][0]),dim=0).to(DEVICE) # 2个图片拼batch, (2,1,48,48) batch_x=batch_x*2-1 # 像素值调整到[-1,1]之间,以便与高斯噪音值范围匹配 batch_cls=torch.tensor([train_dataset[0][1],train_dataset[1][1]],dtype=torch.long).to(DEVICE) # 引导ID batch_t=torch.randint(0,T,size=(batch_x.size(0),)).to(DEVICE) # 每张图片随机生成diffusion步数 batch_x_t,batch_noise_t=forward_diffusion(batch_x,batch_t) print('batch_x_t:',batch_x_t.size()) print('batch_noise_t:',batch_noise_t.size()) unet=UNet(img_channel=1).to(DEVICE) batch_predict_noise_t=unet(batch_x_t,batch_t,batch_cls) print('batch_predict_noise_t:',batch_predict_noise_t.size())
(4)万事俱备,只剩train了!
from config import * from torch.utils.data import DataLoader from dataset import train_dataset from unet import UNet from diffusion import forward_diffusion import torch from torch import nn import os from torch.utils.tensorboard import SummaryWriter EPOCH=200 BATCH_SIZE=800 dataloader=DataLoader(train_dataset,batch_size=BATCH_SIZE,num_workers=4,persistent_workers=True,shuffle=True) # 数据加载器 try: model=torch.load('model.pt') except: model=UNet(1).to(DEVICE) # 噪音预测模型 optimizer=torch.optim.Adam(model.parameters(),lr=0.001) # 优化器 loss_fn=nn.L1Loss() # 损失函数(绝对值误差均值);都是图片,直接对比像素 writer = SummaryWriter() if __name__=='__main__': model.train() n_iter=0 for epoch in range(EPOCH): last_loss=0 for batch_x,batch_cls in dataloader: # 图像的像素范围转换到[-1,1],和高斯分布对应 batch_x=batch_x.to(DEVICE)*2-1 # 引导分类ID batch_cls=batch_cls.to(DEVICE) # 为每张图片生成随机t时刻 batch_t=torch.randint(0,T,(batch_x.size(0),)).to(DEVICE) # 生成t时刻的加噪图片和对应噪音 batch_x_t,batch_noise_t=forward_diffusion(batch_x,batch_t) # 模型预测t时刻的噪音 batch_predict_t=model(batch_x_t,batch_t,batch_cls) # 求损失 loss=loss_fn(batch_predict_t,batch_noise_t) # 优化参数 optimizer.zero_grad() loss.backward() optimizer.step() last_loss=loss.item() writer.add_scalar('Loss/train', last_loss, n_iter) n_iter+=1 print('epoch:{} loss={}'.format(epoch,last_loss)) torch.save(model,'model.pt.tmp')#先写入临时文件 os.replace('model.pt.tmp','model.pt')#原子操作,确保数据落盘安全
用cpu尝试训练:内存消耗并不多,就是cpu撑爆了!
预测的noise和真实noise的loss:
epoch到50多以后基本稳定没有再下降了:
model.pt文件也生成了:
参考:
1、https://arxiv.org/abs/2006.11239 Denoising Diffusion Probabilistic Models
2、https://github.com/hojonathanho/diffusion
3、https://www.bilibili.com/video/BV1g14y1X76j/?spm_id_from=333.337.search-card.all.click&vd_source=241a5bcb1c13e6828e519dd1f78f35b2
4、https://www.bilibili.com/video/BV1R14y1D7kx/?p=2&spm_id_from=pageDriver&vd_source=241a5bcb1c13e6828e519dd1f78f35b2 how-diffusion-models-work
https://www.deeplearning.ai/short-courses/how-diffusion-models-work/
5、https://www.bilibili.com/video/BV1im4reKEbX/?spm_id_from=333.788.recommend_more_video.3&vd_source=241a5bcb1c13e6828e519dd1f78f35b2 基于Pytorch从零实现Stable Diffusion模型 by Umar Jamil
6、https://zhuanlan.zhihu.com/p/590840909 https://github.com/owenliang/pytorch-diffusion 扩散模型DDPM浅析 https://cvmart.net/community/detail/6936