PyTorch实现简单的变分自动编码器VAE
在上一篇博客中我们介绍并实现了自动编码器,本文将用PyTorch实现变分自动编码器(Variational AutoEncoder, VAE)。自动变分编码器原理与一般的自动编码器的区别在于需要在编码过程增加一点限制,迫使它生成的隐含向量能够粗略的遵循标准正态分布。这样一来,当需要生成一张新图片时,只需要给解码器一个标准正态分布的隐含随机向量就可以了。
在实际操作中,实际上不是生成一个隐含向量,而是生成两个向量:一个表示均值,一个表示标准差,然后通过这两个统计量合成隐含向量,用一个标准正态分布先乘标准差再加上均值就行了。具体关于变分自动编码器的内容,可参考廖星宇的《深度学习之PyTorch》的第六章,下面的代码也是来自这个资料,但本文对原代码做了一点改动。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 | import os import torch import torch.nn.functional as F from torch import nn from torch.utils.data import DataLoader from torchvision.datasets import MNIST from torchvision import transforms as tfs from torchvision.utils import save_image # Hyper parameters EPOCH = 1 LR = 1e - 3 BATCHSIZE = 128 im_tfs = tfs.Compose([ tfs.ToTensor(), # Converts a PIL.Image or numpy.ndarray to # torch.FloatTensor of shape (C x H x W) and normalize in the range [0.0, 1.0] tfs.Normalize([ 0.5 ], [ 0.5 ]) # 把[0.0, 1.0]的数据扩大范围到[-1., 1] ]) train_set = MNIST( root = '/Users/wangpeng/Desktop/all/CS/Courses/Deep Learning/mofan_PyTorch/mnist/' , # mnist has been downloaded before, use it directly train = True , transform = im_tfs, ) train_loader = DataLoader(train_set, batch_size = BATCHSIZE, shuffle = True ) class VAE(nn.Module): def __init__( self ): super (VAE, self ).__init__() self .fc1 = nn.Linear( 784 , 400 ) self .fc21 = nn.Linear( 400 , 20 ) # mean self .fc22 = nn.Linear( 400 , 20 ) # var self .fc3 = nn.Linear( 20 , 400 ) self .fc4 = nn.Linear( 400 , 784 ) def encode( self , x): h1 = F.relu( self .fc1(x)) return self .fc21(h1), self .fc22(h1) def reparametrize( self , mu, logvar): std = logvar.mul( 0.5 ).exp_() # 矩阵点对点相乘之后再把这些元素作为e的指数 eps = torch.FloatTensor(std.size()).normal_() # 生成随机数组 if torch.cuda.is_available(): eps = eps.cuda() return eps.mul(std).add_(mu) # 用一个标准正态分布乘标准差,再加上均值,使隐含向量变为正太分布 def decode( self , z): h3 = F.relu( self .fc3(z)) return torch.tanh( self .fc4(h3)) def forward( self , x): mu, logvar = self .encode(x) # 编码 z = self .reparametrize(mu, logvar) # 重新参数化成正态分布 return self .decode(z), mu, logvar # 解码,同时输出均值方差 net = VAE() # 实例化网络 if torch.cuda.is_available(): net = net.cuda() reconstruction_function = nn.MSELoss(size_average = False ) def loss_function(recon_x, x, mu, logvar): """ recon_x: generating images x: origin images mu: latent mean logvar: latent log variance """ MSE = reconstruction_function(recon_x, x) # loss = 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2) KLD_element = mu. pow ( 2 ).add_(logvar.exp()).mul_( - 1 ).add_( 1 ).add_(logvar) KLD = torch. sum (KLD_element).mul_( - 0.5 ) # KL divergence return MSE + KLD optimizer = torch.optim.Adam(net.parameters(), lr = LR) def to_img(x): # x shape (bachsize, 28*28), x中每个像素点的大小范围[-1., 1.] ''' 定义一个函数将最后的结果转换回图片 ''' x = 0.5 * (x + 1. ) x = x.clamp( 0 , 1 ) x = x.view(x.shape[ 0 ], 1 , 28 , 28 ) return x for epoch in range (EPOCH): for iteration, (im, y) in enumerate (train_loader): im = im.view(im.shape[ 0 ], - 1 ) if torch.cuda.is_available(): im = im.cuda() recon_im, mu, logvar = net(im) loss = loss_function(recon_im, im, mu, logvar) / im.shape[ 0 ] # 将 loss 平均 optimizer.zero_grad() loss.backward() optimizer.step() if iteration % 100 = = 0 : print ( 'epoch: {:2d} | iteration: {:4d} | Loss: {:.4f}' . format (epoch, iteration, loss.data.numpy())) save = to_img(recon_im.cpu().data) if not os.path.exists( './vae_img' ): os.mkdir( './vae_img' ) save_image(save, './vae_img/image_{}_{}.png' . format (epoch, iteration)) # test code = torch.randn( 1 , 20 ) # 随机给一个符合正态分布的张量 out = net.decode(code) img = to_img(out) save_image(img, './vae_img/test_img.png' ) |
分类:
Deep Learning
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· 基于Microsoft.Extensions.AI核心库实现RAG应用
· Linux系列:如何用heaptrack跟踪.NET程序的非托管内存泄露
· 开发者必知的日志记录最佳实践
· SQL Server 2025 AI相关能力初探
· Linux系列:如何用 C#调用 C方法造成内存泄露
· Manus爆火,是硬核还是营销?
· 终于写完轮子一部分:tcp代理 了,记录一下
· 别再用vector<bool>了!Google高级工程师:这可能是STL最大的设计失误
· 震惊!C++程序真的从main开始吗?99%的程序员都答错了
· 单元测试从入门到精通