深度学习(VAE)

变分自编码器(VAE,Variational Auto-Encoder)是一种生成模型,它通过学习数据的潜在表示来生成新的样本。

在学习潜空间时,需要保持生成样本与真实数据的相似性,并尽量让潜变量的分布接近标准正态分布。

VAE的基本结构:

1. 编码器(Encoder):将输入数据转换为潜在空间的分布,输出潜在变量的均值和方差。

2. 重参数化层(Reparameterization Layer):从编码器输出的均值和方差中进行重参数化采样,生成潜在变量。

3. 解码器(Decoder):接收潜在变量并将其转换回原始数据的分布。

为了让生成样本接近原始数据,最终loss是样本与真实数据相似度和潜变量与标准高斯分布相似度之和。

生成样本和真实数据相似度可以通过mse计算。

潜变量与标准高斯分布相似度可以通过KL散度计算。

下面是两个高斯分布计算KL散度的推导:

设其中一个为标准高斯函数:

下面代码是用FashionMNIST作为数据集,生成样本的示例:

复制代码
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms,datasets
from torchvision.utils import save_image

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

#dataset = datasets.MNIST(root='./data',train=True,transform=transforms.ToTensor(),download=True)
dataset = datasets.FashionMNIST(root='./fasion_data',train=True,transform=transforms.ToTensor(),download=True)

data_loader = torch.utils.data.DataLoader(dataset=dataset,
                                          batch_size=128, 
                                          shuffle=True)

class VAE(nn.Module):
    def __init__(self, image_size=784, h=400, z=20):
        super(VAE, self).__init__()
        self.fc1 = nn.Linear(image_size, h)
        self.fc2 = nn.Linear(h, z)
        self.fc3 = nn.Linear(h, z)

        self.fc4 = nn.Linear(z, h)
        self.fc5 = nn.Linear(h, image_size)
        
    def encode(self, x):
        h = F.relu(self.fc1(x))
        mu = self.fc2(h)
        log_var = self.fc3(h)
        return mu,log_var 
    
    def reparameterize(self, mu, log_var):
        std = torch.exp(log_var/2)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        h = F.relu(self.fc4(z))
        reconst_x = F.sigmoid(self.fc5(h))
        return reconst_x
    
    def forward(self, x):
        mu, log_var = self.encode(x)
        z = self.reparameterize(mu, log_var)
        reconst_x = self.decode(z)
        return reconst_x, mu, log_var

def loss_function(reconst_x, x, mu, log_var): 
    mse = F.binary_cross_entropy(reconst_x, x, size_average=False)
    kld = - 0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
    return mse+kld


model = VAE().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

for epoch in range(10):
    for i, (x, _) in enumerate(data_loader):

        x = x.to(device).view(-1, 784)
        reconst_x, mu, log_var = model(x)
     
        loss = loss_function(reconst_x,x,mu,log_var) 
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        if (i+1) % 10 == 0:
            print("epoch : ",epoch, "loss:", loss.item())
    
    with torch.no_grad():
        out, _, _ = model(x)
        x_concat = torch.cat([x.view(-1, 1, 28, 28), out.view(-1, 1, 28, 28)], dim=3)
        save_image(x_concat, os.path.join('./', '{}.png'.format(epoch)))
复制代码

 结果如下:

posted @   Dsp Tian  阅读(121)  评论(0编辑  收藏  举报
相关博文:
阅读排行:
· 震惊!C++程序真的从main开始吗?99%的程序员都答错了
· 【硬核科普】Trae如何「偷看」你的代码?零基础破解AI编程运行原理
· 单元测试从入门到精通
· 上周热点回顾(3.3-3.9)
· winform 绘制太阳,地球,月球 运作规律
点击右上角即可分享
微信分享提示