PyTorch (26): Autoencoders

I. Autoencoder (AE)
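
The AE flattens each 28x28 MNIST image to [b, 784], compresses it to a 20-dim code, reconstructs the 784 pixels, and trains with a pixel-wise MSE loss.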

1. auto_encoder.py

import torch
from torch import nn

class AE(nn.Module):
    def __init__(self):
        super(AE, self).__init__()

        #[b, 784] => [b, 20]
        self.encoder = nn.Sequential(
            nn.Linear(784, 256),
            nn.ReLU(),
            nn.Linear(256, 64),
            nn.ReLU(),
            nn.Linear(64, 20),
            nn.ReLU()
        )

        #[b, 20] => [b, 784]
        self.decoder = nn.Sequential(
            nn.Linear(20, 64),
            nn.ReLU(),
            nn.Linear(64, 256),
            nn.ReLU(),
            nn.Linear(256, 784),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """
        :param x: [b, 1, 28, 28]
        :return:
        """
        batchsz = x.shape[0]
        #flatten
        x = x.view(batchsz, 784)
        #encoder
        x = self.encoder(x)
        #decoder
        x = self.decoder(x)
        #reshape
        x = x.view(batchsz, 1, 28, 28)

        #None keeps the return signature identical to the VAE below, so the
        #same training loop can handle both models
        return x, None
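
A quick shape sanity check of the forward pass (a minimal sketch; the random tensor just stands in for an MNIST batch):

import torch
from auto_encoder import AE

model = AE()
x = torch.randn(2, 1, 28, 28)   # fake batch of two images
x_hat, _ = model(x)
print(x_hat.shape)              # torch.Size([2, 1, 28, 28])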

2. main.py

import torch
from torch.utils.data import DataLoader
from torchvision import transforms, datasets
from auto_encoder import AE
from torch import nn, optim
import visdom


def main():
    mnist_train = datasets.MNIST("mnist", train=True, transform=transforms.Compose([
        transforms.ToTensor()
    ]), download=True)
    mnist_train = DataLoader(mnist_train, batch_size=32, shuffle=True)

    mnist_test = datasets.MNIST("mnist", train=False, transform=transforms.Compose([
        transforms.ToTensor()
    ]), download=True)
    mnist_test = DataLoader(mnist_test, batch_size=32, shuffle=True)

    x, _ = next(iter(mnist_train))
    print(x.shape)

    model = AE()
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    viz = visdom.Visdom()
    for epoch in range(1000):
        for batchidx, (x, _) in enumerate(mnist_train):
            #[b, 1, 28, 28]
            x_hat, _ = model(x)
            loss = criterion(x_hat, x)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        print(epoch, "loss:", loss.item())
        x, _ = next(iter(mnist_test))
        with torch.no_grad():
            x_hat, _ = model(x)
        viz.images(x, nrow=8, win="x", opts=dict(title="x"))
        viz.images(x_hat, nrow=8, win="x_hat", opts=dict(title="x_hat"))

if __name__ == '__main__':
    main()
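
Since the decoder ends in a Sigmoid, the pixels of x_hat lie in [0, 1], so binary cross-entropy is a common alternative to MSE for the reconstruction loss. A hedged sketch (the random tensors stand in for a real batch and its reconstruction):

import torch
from torch import nn

criterion = nn.BCELoss()           # drop-in replacement for nn.MSELoss()
x = torch.rand(32, 1, 28, 28)      # stand-in input batch in [0, 1]
x_hat = torch.rand(32, 1, 28, 28)  # stand-in Sigmoid output
loss = criterion(x_hat, x)
print(loss.item())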

II. Variational Autoencoder (VAE)
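
The VAE keeps the same encoder/decoder shape but makes the code stochastic: the encoder's 20 outputs are split into mu and sigma, a latent h is sampled with the reparameterization trick, and a KL term pushes the latent distribution toward N(0, I).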

1. vae.py

import torch
from torch import nn

class VAE(nn.Module):
    def __init__(self):
        super(VAE, self).__init__()

        #[b, 784] => [b, 20]
        #the 20 outputs split into mu: [b, 10] and sigma: [b, 10]
        self.encoder = nn.Sequential(
            nn.Linear(784, 256),
            nn.ReLU(),
            nn.Linear(256, 64),
            nn.ReLU(),
            nn.Linear(64, 20)
            #no activation on the last layer: the first 10 outputs are mu,
            #which must be free to take negative values
        )

        #[b, 20] => [b, 784]
        self.decoder = nn.Sequential(
            nn.Linear(10, 64),
            nn.ReLU(),
            nn.Linear(64, 256),
            nn.ReLU(),
            nn.Linear(256, 784),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """
        :param x: [b, 1, 28, 28]
        :return:
        """
        batchsz = x.shape[0]
        #flatten
        x = x.view(batchsz, 784)
        #encoder
        #[b, 20]: the first 10 dims are mu, the last 10 are sigma
        h_ = self.encoder(x)
        #[b, 20] => [b, 10] and [b, 10]
        mu, sigma = h_.chunk(2, dim=1)
        #reparameterization trick: h = mu + sigma * epsilon, epsilon ~ N(0, 1), [b, 10]
        h = mu + sigma * torch.randn_like(sigma)

        #closed-form KL divergence between N(mu, sigma^2) and N(0, 1), summed
        #over the batch and averaged per pixel so it stays on the same scale
        #as the per-pixel MSE reconstruction loss
        kld = 0.5 * torch.sum(
            torch.pow(mu, 2) +
            torch.pow(sigma, 2) -
            torch.log(1e-8 + torch.pow(sigma, 2)) - 1
        ) / (batchsz * 28 * 28)

        #decoder
        x = self.decoder(h)
        #reshape
        x = x.view(batchsz, 1, 28, 28)

        return x, kld
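
Because the KL term pulls the latent distribution toward N(0, I), a trained VAE can generate new digits by decoding samples drawn from the prior. A minimal sketch (assumes model is a trained VAE instance):

import torch

z = torch.randn(8, 10)               # 8 latent codes sampled from N(0, I)
with torch.no_grad():
    samples = model.decoder(z)       # [8, 784], values in [0, 1]
samples = samples.view(8, 1, 28, 28)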

2. main.py

import torch
from torch.utils.data import DataLoader
from torchvision import transforms, datasets
from vae import VAE
from torch import nn, optim
import visdom


def main():
    mnist_train = datasets.MNIST("mnist", train=True, transform=transforms.Compose([
        transforms.ToTensor()
    ]), download=True)
    mnist_train = DataLoader(mnist_train, batch_size=32, shuffle=True)

    mnist_test = datasets.MNIST("mnist", train=False, transform=transforms.Compose([
        transforms.ToTensor()
    ]), download=True)
    mnist_test = DataLoader(mnist_test, batch_size=32, shuffle=True)

    x, _ = next(iter(mnist_train))
    print(x.shape)

    model = VAE()
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    viz = visdom.Visdom()
    for epoch in range(1000):
        for batchidx, (x, _) in enumerate(mnist_train):
            #[b, 1, 28, 28]
            x_hat, kld = model(x)
            loss = criterion(x_hat, x)

            if kld is not None:
                loss = loss + 1.0 * kld

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        print(epoch, "loss:", loss.item(), "kld:", kld.item())
        x, _ = next(iter(mnist_test))
        with torch.no_grad():
            x_hat, _ = model(x)
        viz.images(x, nrow=8, win="x", opts=dict(title="x"))
        viz.images(x_hat, nrow=8, win="x_hat", opts=dict(title="x_hat"))

if __name__ == '__main__':
    main()
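
The 1.0 weight on the KL term plays the role of beta in a beta-VAE: raising it trades reconstruction fidelity for a latent space closer to the prior, and lowering it does the opposite. A hedged variant of the loss lines inside the training loop above (the name beta is my own):

beta = 4.0   # hypothetical weight; beta = 1.0 reproduces the code above
loss = criterion(x_hat, x)
if kld is not None:
    loss = loss + beta * kld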
