VAE(变分自编码器)的 PyTorch 实现 —— Jupyter 版本(注意:此处使用的是 tqdm.notebook,与普通脚本中的 tqdm 模块不同)
简单实现了torch版本的变分自编码器
参考大佬TensorFlow版本的VAE:膜拜大佬
import os
import numpy as np
from PIL import Image
from matplotlib import pyplot as plt
import torch
from torchvision import datasets, transforms
import torch.nn as nn
from time import sleep
from tqdm.notebook import tqdm
class CFG:
    """Hyperparameters read by the data loader, the model and the training loop."""

    # Number of samples per mini-batch.
    batch_size = 10
    # Dimensionality of the latent code z.
    z_dim = 10
    # Number of passes over the training set.
    epoch = 1000
    # Learning rate handed to the optimizer.
    lr = 1e-4
# MNIST training split, downloaded into ./mnist-data on first use; ToTensor
# converts each PIL image to a float tensor.
mnist_train = datasets.MNIST(
    "mnist-data",
    train=True,
    download=True,
    transform=transforms.ToTensor(),
)
train_loader = torch.utils.data.DataLoader(
    mnist_train,
    batch_size=CFG.batch_size,
    shuffle=True,
)

# Sanity check: print the image-tensor shape of the first batch only.
for images, _labels in train_loader:
    print(images.shape)
    break
class VAE(nn.Module):
    """Fully connected variational autoencoder for flattened inputs.

    The encoder maps an input vector to the mean and log-variance of a
    diagonal Gaussian over the latent code; the decoder maps a latent sample
    back to a vector of per-element values squashed into [0, 1] by a sigmoid.

    Args:
        input_dim: Size of the flattened input (default 784, i.e. 28 * 28).
        hidden_dim: Width of the single hidden layer used by both the
            encoder and the decoder.
        z_dim: Latent dimensionality. ``None`` (the default) falls back to
            ``CFG.z_dim`` so that ``VAE()`` keeps following the global config.
    """

    def __init__(self, input_dim=784, hidden_dim=128, z_dim=None):
        super().__init__()
        if z_dim is None:
            z_dim = CFG.z_dim
        # Encoder: one shared hidden layer, then separate linear heads for the
        # latent mean (e2) and the latent log-variance (e3).
        self.e1 = nn.Linear(input_dim, hidden_dim)
        self.e2 = nn.Linear(hidden_dim, z_dim)
        self.e3 = nn.Linear(hidden_dim, z_dim)
        # Decoder: its input width must equal the latent size produced by the
        # encoder heads, so it is tied to z_dim rather than a literal.
        self.fc4 = nn.Linear(z_dim, hidden_dim)
        self.fc5 = nn.Linear(hidden_dim, input_dim)

    def reparameterize(self, mean, log_var):
        """Sample ``z = mean + eps * std`` with ``eps`` drawn from N(0, I).

        Args:
            mean: Latent means.
            log_var: Latent log-variances, same shape as ``mean``.

        Returns:
            A latent sample with the same shape as ``mean``.
        """
        # randn_like creates the noise on log_var's device and in its dtype,
        # so the model keeps working after being moved off the CPU.
        eps = torch.randn_like(log_var)
        # std = sqrt(exp(log_var)), computed as exp(log_var / 2) to avoid
        # overflowing the intermediate exp for large log-variances.
        std = torch.exp(0.5 * log_var)
        return mean + eps * std

    def encoder(self, inputs):
        """Return ``(mean, log_var)`` of the approximate posterior for ``inputs``."""
        h = torch.relu(self.e1(inputs))
        return self.e2(h), self.e3(h)

    def decoder(self, z):
        """Map latent codes ``z`` to reconstruction logits (no sigmoid applied)."""
        return self.fc5(torch.relu(self.fc4(z)))

    def forward(self, inputs):
        """Encode, sample and decode.

        Returns:
            A tuple ``(x_hat, mean, log_var)`` where ``x_hat`` is the
            sigmoid-activated reconstruction with values in [0, 1].
        """
        mean, log_var = self.encoder(inputs)
        z = self.reparameterize(mean, log_var)
        x_hat = torch.sigmoid(self.decoder(z))
        return x_hat, mean, log_var
# Train the VAE by minimising the negative ELBO:
# reconstruction loss + KL(q(z|x) || N(0, I)), both averaged per sample.
model = VAE()
model.train()
# model(...) returns sigmoid-activated values in [0, 1], so the reconstruction
# term is a binary cross-entropy on probabilities, summed over all elements.
cross_entroy_loss = torch.nn.BCELoss(reduction="sum")
optimizer = torch.optim.Adam(model.parameters(), lr=CFG.lr)

for epoch in range(1, CFG.epoch + 1):
    loop = tqdm(train_loader, total=len(train_loader))
    loop.set_description(f'Epoch [{epoch}/{CFG.epoch}]')
    for x in loop:
        # x is an (images, labels) batch; the labels are not needed here.
        x_ = x[0].reshape(-1, 784)
        n_samples = x_.shape[0]

        optimizer.zero_grad()
        x_rec, mean, log_var = model(x_)

        # Reconstruction term: BCE(input=reconstruction, target=data),
        # summed over pixels and averaged over the batch.
        rec_loss = cross_entroy_loss(x_rec, x_) / n_samples
        # Closed-form KL between N(mean, exp(log_var)) and N(0, I), summed
        # over latent dimensions and averaged over the batch.
        kl_div = -0.5 * (log_var + 1 - mean ** 2 - torch.exp(log_var))
        kl_div = torch.sum(kl_div) / n_samples

        loss = rec_loss + 1.0 * kl_div
        loss.backward()
        optimizer.step()

        loop.set_postfix(loss=loss.item(), Kl_div=kl_div.item(), rec_loss=rec_loss.item())