PyTorch 之 AutoEncoder(encoder / decoder)
### 仅为自己练习，没有其他用途
import torch
import torch.nn as nn
import torch.utils.data as Data
import torchvision
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401 -- registers the '3d' projection on older matplotlib
from matplotlib import cm
import numpy as np


# torch.manual_seed(1)    # reproducible

# Hyper Parameters
EPOCH = 10
BATCH_SIZE = 64
LR = 0.005              # learning rate
DOWNLOAD_MNIST = False  # set to True if ./mnist/ does not hold the dataset yet
N_TEST_IMG = 5          # number of images shown in the reconstruction preview

# Mnist digits dataset
train_data = torchvision.datasets.MNIST(
    root='./mnist/',
    train=True,                                     # this is training data
    transform=torchvision.transforms.ToTensor(),    # Converts a PIL.Image or numpy.ndarray to
                                                    # torch.FloatTensor of shape (C x H x W) and normalize in the range [0.0, 1.0]
    download=DOWNLOAD_MNIST,                        # download it if you don't have it
)

# plot one example
# `.data` / `.targets` are the current names; `.train_data` / `.train_labels`
# are deprecated aliases in torchvision.
print(train_data.data.size())      # (60000, 28, 28)
print(train_data.targets.size())   # (60000)
plt.imshow(train_data.data[2].numpy(), cmap='gray')
plt.title('%i' % train_data.targets[2].item())
plt.show()

# Data Loader for easy mini-batch return in training; each image batch has
# shape (BATCH_SIZE, 1, 28, 28) -- the final batch of an epoch may be smaller.
train_loader = Data.DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)


class AutoEncoder(nn.Module):
    """Fully connected autoencoder mapping 28*28 inputs to 3 features and back.

    The encoder compresses a flattened 28x28 image down to 3 features (few
    enough to be plotted in 3-D). The decoder expands those 3 features back to
    28*28 values, squashed into (0, 1) by a final Sigmoid.
    """

    def __init__(self):
        super(AutoEncoder, self).__init__()

        self.encoder = nn.Sequential(
            nn.Linear(28 * 28, 128),
            nn.Tanh(),
            nn.Linear(128, 64),
            nn.Tanh(),
            nn.Linear(64, 12),
            nn.Tanh(),
            nn.Linear(12, 3),   # compress to 3 features which can be visualized in plt
        )
        self.decoder = nn.Sequential(
            nn.Linear(3, 12),
            nn.Tanh(),
            nn.Linear(12, 64),
            nn.Tanh(),
            nn.Linear(64, 128),
            nn.Tanh(),
            nn.Linear(128, 28 * 28),
            nn.Sigmoid(),       # compress to a range (0, 1)
        )

    def forward(self, x):
        """Encode then decode ``x``.

        Args:
            x: float tensor whose last dimension is 28*28 (a flattened image).

        Returns:
            ``(encoded, decoded)``: the 3-feature code and the 28*28 reconstruction.
        """
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        return encoded, decoded


autoencoder = AutoEncoder()

optimizer = torch.optim.Adam(autoencoder.parameters(), lr=LR)
loss_func = nn.MSELoss()

# initialize figure: row 0 shows the original images, row 1 their reconstructions
f, a = plt.subplots(2, N_TEST_IMG, figsize=(5, 2))
plt.ion()   # continuously plot (interactive mode lets the figure redraw during training)

# original data (first row) for viewing; raw uint8 pixels are scaled to [0, 1]
# to match the ToTensor() inputs the network is trained on
view_data = train_data.data[:N_TEST_IMG].view(-1, 28 * 28).float() / 255.
for i in range(N_TEST_IMG):
    a[0][i].imshow(np.reshape(view_data.numpy()[i], (28, 28)), cmap='gray')
    a[0][i].set_xticks(())
    a[0][i].set_yticks(())

for epoch in range(EPOCH):
    for step, (x, _) in enumerate(train_loader):   # labels are not needed to train an autoencoder
        b_x = x.view(-1, 28 * 28)   # batch x, shape (batch, 28*28)
        b_y = x.view(-1, 28 * 28)   # batch y, shape (batch, 28*28): the target is the input itself

        encoded, decoded = autoencoder(b_x)

        loss = loss_func(decoded, b_y)      # mean square error
        optimizer.zero_grad()               # clear gradients for this training step
        loss.backward()                     # backpropagation, compute gradients
        optimizer.step()                    # apply gradients

        if step % 100 == 0:
            print('Epoch: ', epoch, '| train loss: %.4f' % loss.item())

            # plotting decoded image (second row); display only, so no autograd graph
            with torch.no_grad():
                _, decoded_data = autoencoder(view_data)
            for i in range(N_TEST_IMG):
                a[1][i].clear()
                a[1][i].imshow(np.reshape(decoded_data.numpy()[i], (28, 28)), cmap='gray')
                a[1][i].set_xticks(())
                a[1][i].set_yticks(())
            plt.draw()
            plt.pause(0.05)

plt.ioff()
plt.show()

# visualize in 3D plot: the 3-feature codes of the first 200 training images
view_data = train_data.data[:200].view(-1, 28 * 28).float() / 255.
with torch.no_grad():
    encoded_data, _ = autoencoder(view_data)
fig = plt.figure(2)
# `Axes3D(fig)` stopped attaching itself to the figure in newer matplotlib
# (deprecated 3.4, removed 3.6), which leaves the figure empty; add_subplot
# with projection='3d' works across versions.
ax = fig.add_subplot(111, projection='3d')
X, Y, Z = encoded_data[:, 0].numpy(), encoded_data[:, 1].numpy(), encoded_data[:, 2].numpy()
values = train_data.targets[:200].numpy()
for x, y, z, s in zip(X, Y, Z, values):
    c = cm.rainbow(int(255 * s / 9))    # spread digit labels 0..9 across the rainbow colormap
    ax.text(x, y, z, str(s), backgroundcolor=c)
ax.set_xlim(X.min(), X.max())
ax.set_ylim(Y.min(), Y.max())
ax.set_zlim(Z.min(), Z.max())
plt.show()