transgan_pytorch

transgan_pytorch

Implementation of the paper:

Yifan Jiang, Shiyu Chang and Zhangyang Wang. TransGAN: Two Pure Transformers Can Make One Strong GAN, and That Can Scale Up.

网络结构

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

"""
Paper uses Vaswani (2017) Attention with minimal changes.
Multi-head self-attention with a feed-forward MLP
with GELU non-linearity. Layer normalisation is used
before each segment and employs residual skip connections.
"""
class Attention(nn.Module):
    def __init__(self, D, heads=8):
        super().__init__()
        self.D = D
        self.heads = heads

        assert (D % heads == 0), "Embedding size should be divisble by number of heads"
        self.head_dim = self.D // heads

        self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.H = nn.Linear(self.D, self.D)

    def forward(self, Q, K, V, mask):
        batch_size = Q.shape[0]
        q_len, k_len, v_len = Q.shape[1], K.shape[1], V.shape[1]

        Q = Q.reshape(batch_size, q_len, self.heads, self.head_dim)
        K = K.reshape(batch_size, k_len, self.heads, self.head_dim)
        V = V.reshape(batch_size, v_len, self.heads, self.head_dim)

        # performing batch-wise matrix multiplication
        raw_scores = torch.einsum("bqhd,bkhd->bhqk", [Q, K])

        # shut off triangular matrix with very small value
        scores = raw_scores.masked_fill(mask == 0, -np.inf) if mask else raw_scores

        attn = torch.softmax(scores / np.sqrt(self.D), dim=3)
        attn_output = torch.einsum("bhql,blhd->bqhd", [attn, V])
        attn_output = attn_output.reshape(batch_size, q_len, self.D)

        output = self.H(attn_output)

        return output

class EncoderBlock(nn.Module):
    def __init__(self, D, heads, p, fwd_exp):
        super().__init__()
        self.mha = Attention(D, heads)
        self.drop_prob = p
        self.n1 = nn.LayerNorm(D)
        self.n2 = nn.LayerNorm(D)
        self.mlp = nn.Sequential(
            nn.Linear(D, fwd_exp*D),
            nn.ReLU(),
            nn.Linear(fwd_exp*D, D),
        )
        self.dropout = nn.Dropout(p)

    def forward(self, Q, K, V, mask):
        attn = self.mha(Q, K, V, mask)

        """
        Layer normalisation with residual connections
        """
        x = self.n1(attn + Q)
        x = self.dropout(x)
        forward = self.mlp(x)
        x = self.n2(forward + x)
        out = self.dropout(x)

        return out

class MLP(nn.Module):
    def __init__(self, noise_w, noise_h, channels):
        super().__init__()
        self.l1 = nn.Linear(
                    noise_w*noise_h*channels, 
                    (8*8)*noise_w*noise_h*channels, 
                    bias=False
                )

    def forward(self, x):
        out = self.l1(x)
        return out

class PixelShuffle(nn.Module):
    def __init__(self):
        super().__init__()
        pass

class Generator(nn.Module):
    def __init__(self):
        super().__init__()
        self.mlp = MLP(32, 32, 1)
        
        # stage 1
        self.s1_enc = nn.ModuleList([
                        EncoderBlock(1024*8*8)
                        for _ in range(5)
                    ])

        # stage 2
        self.s2_pix_shuffle = PixelShuffle()
        self.s2_enc = nn.ModuleList([
                        EncoderBlock(256*16*16)
                        for _ in range (4)
                    ])

        # stage 3
        self.s3_pix_shuffle = PixelShuffle()
        self.s3_enc = nn.ModuleList([
                        EncoderBlock(64*32*32)
                        for _ in range(2)
                    ])

        # stage 4
        self.linear = nn.Linear(32*32*64, 32*32*3)

    def forward(self, noise):
        x = self.mlp(noise)
        for layer in self.s1_enc:
            x = layer(x)
        
        x = self.s2_pix_shuffle(x)
        for layer in self.s2_enc:
            x = layer(x)

        x - self.s3_pix_shuffle(x)
        for layer in self.s3_enc:
            x = layer(x)

        img = self.linear(x)

        return img

class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()

        self.l1 = nn.Linear(32*32*3, (8*8+1)*384)
        self.s2_enc = nn.ModuleList([
                        EncoderBlock((8*8+1)*284)
                        for _ in range(7)
                    ])

        self.classification_head = nn.Linear(1*384, 1)

    def forward(self, img):
        x = self.l1(img)
        for layer in self.s2_enc:
            x = layer(x)

        logits = self.classification_head(x)
        pred = F.softmax(logits)
        
        return pred

class TransGAN_S(nn.Module):
    def __init__(self):
        super().__init__()
        self.gen = Generator()
        self.disc = Discriminator()

    def forward(self, noise):
        img = self.gen(noise)
        pred = self.disc(img)

        return img, pred

测试代码

from transgan_pytorch.transgan_pytorch import TransGAN
            z_dim=100,
            output_gim=32*32
        )
posted @ 2021-10-21 21:10  梁君牧  阅读(272)  评论(2编辑  收藏  举报