VAE (Variational Autoencoder, PyTorch implementation) — Jupyter notebook version (note the tqdm module differs)

A simple PyTorch implementation of the variational autoencoder (VAE).

Adapted from another author's TensorFlow implementation of the VAE.
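For context (my own summary, not from the original post): the loss minimized below is the usual negative ELBO, a reconstruction term plus a KL term that pulls the approximate posterior towards a standard normal prior:

$$\mathcal{L}(x) = \mathbb{E}_{q(z\mid x)}\big[-\log p(x\mid z)\big] + D_{\mathrm{KL}}\big(q(z\mid x)\,\|\,\mathcal{N}(0, I)\big)$$

With a diagonal Gaussian posterior $q(z\mid x)=\mathcal{N}(\mu,\sigma^2)$, the KL term has the closed form $-\tfrac{1}{2}\sum_j\big(1+\log\sigma_j^2-\mu_j^2-\sigma_j^2\big)$, which is exactly what the training loop computes.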

import os
import numpy as np
from PIL import Image
from matplotlib import pyplot as plt
import torch
from torchvision import datasets, transforms
import torch.nn as nn
from time import sleep
from tqdm.notebook import tqdm
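As the title hints, the notebook flavour of tqdm is imported here so the progress bar renders as a Jupyter widget. When running the same code as a plain .py script, swap that single import (a small aside, not from the original post):

# from tqdm import tqdm   # console progress bar, use this instead outside Jupyter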
class CFG:
    batch_size = 10   # mini-batch size
    z_dim = 10        # dimension of the latent code z
    epoch = 1000      # number of training epochs
    lr = 0.0001       # learning rate for Adam
mnist_train = datasets.MNIST("mnist-data", train=True, download=True, transform=transforms.ToTensor())
train_loader = torch.utils.data.DataLoader(mnist_train, batch_size= CFG.batch_size, shuffle=True)
for i in train_loader:
    print(i[0].shape)
    break
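The check above should print torch.Size([10, 1, 28, 28]): ten greyscale 28 × 28 images per batch. The training loop later flattens each image into a 784-dimensional vector (28 × 28 = 784).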
class VAE(nn.Module):
    def __init__(self):
        super().__init__()
        # encoder: 784 -> 128 -> (mean, log-variance)
        self.e1 = nn.Linear(784, 128)
        self.e2 = nn.Linear(128, CFG.z_dim)   # mean head
        self.e3 = nn.Linear(128, CFG.z_dim)   # log-variance head
        
        # decoder: z_dim -> 128 -> 784
        self.fc4 = nn.Linear(CFG.z_dim, 128)
        self.fc5 = nn.Linear(128, 784)
    
    def reparameterize(self, mean, log_var):
        # z = mean + std * eps with eps ~ N(0, I); randn_like keeps dtype/device consistent
        eps = torch.randn_like(log_var)
        std = torch.exp(0.5 * log_var)
        z = mean + eps * std
        return z
    
    def encoder(self, inputs):
        h = torch.relu(self.e1(inputs))
        mean = self.e2(h)
        log_var = self.e3(h)
        return mean, log_var
    
    def decoder(self, z):
        # returns reconstruction logits; the sigmoid is applied in forward()
        return self.fc5(torch.relu(self.fc4(z)))
    
    def forward(self, inputs):
        mean, log_var = self.encoder(inputs)
        z = self.reparameterize(mean, log_var)
        x_hat = self.decoder(z)
        x_hat = torch.sigmoid(x_hat)
        return x_hat, mean, log_var
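A quick shape check (my own addition, not in the original post) confirms the encoder and decoder dimensions line up before training starts:

_m = VAE()
_x = torch.rand(4, 784)                    # a fake batch of 4 flattened "images"
_x_hat, _mu, _logvar = _m(_x)
print(_x_hat.shape, _mu.shape, _logvar.shape)
# expected: torch.Size([4, 784]) torch.Size([4, 10]) torch.Size([4, 10])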
model = VAE()
model.train()
# the model output is passed through a sigmoid, so pixel-wise binary cross-entropy is the natural reconstruction loss
bce_loss = torch.nn.BCELoss(reduction='sum')
optimizer = torch.optim.Adam(model.parameters(), lr=CFG.lr)
for epoch in range(1, CFG.epoch + 1):
    loop = tqdm(train_loader, total=len(train_loader))
    for x in loop:
        x_, y_ = x[0].reshape(-1, 784), x[1]   # flatten 28x28 images to 784-d vectors
        optimizer.zero_grad()
        x_hat, mean, log_var = model(x_)
        # reconstruction term: BCE summed over pixels, averaged over the batch
        rec_loss = bce_loss(x_hat, x_) / x_.shape[0]
        # KL(q(z|x) || N(0, I)) in closed form, summed over latent dims, averaged over the batch
        kl_div = -0.5 * (log_var + 1 - mean ** 2 - torch.exp(log_var))
        kl_div = torch.sum(kl_div) / x_.shape[0]
        
        loss = rec_loss + 1.0 * kl_div
        loss.backward()
        optimizer.step()
        
        loop.set_description(f'Epoch [{epoch}/{CFG.epoch}]')
        loop.set_postfix(loss=loss.item(), kl_div=kl_div.item(), rec_loss=rec_loss.item())
        sleep(0.05)  # only here so the notebook progress bar refreshes smoothly
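Once training finishes, sampling latent codes from the prior and decoding them produces generated digits. A minimal sketch of how that could look (my own addition, not part of the original post), reusing the plt import from above:

model.eval()
with torch.no_grad():
    z = torch.randn(16, CFG.z_dim)             # sample latent codes from the standard normal prior
    samples = torch.sigmoid(model.decoder(z))  # decoder() returns logits, so apply the sigmoid here
    samples = samples.reshape(-1, 28, 28)
fig, axes = plt.subplots(4, 4, figsize=(4, 4))
for img, ax in zip(samples, axes.flatten()):
    ax.imshow(img.numpy(), cmap='gray')
    ax.axis('off')
plt.show()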