# Mnist手写数字自编码+分类实验 (MNIST handwritten-digit autoencoder + classification experiment)

import torch
import torch.nn as nn
import torch.nn.functional as F
import random
import numpy as np
import matplotlib.pyplot as plt
import torchvision
class AutoEncodeNet(nn.Module):
    """Fully connected MNIST autoencoder with a classification head on the code.

    The encoder compresses a flattened 28*28 image to a 3-dimensional code
    (small enough to visualise in 3D). The decoder reconstructs the image
    from that code, and the classifier predicts one of 10 classes from the
    same code.
    """

    def __init__(self):
        super().__init__()
        # Encoder: 784 -> 128 -> 64 -> 12 -> 3.
        self.encoder = nn.Sequential(
            nn.Linear(28 * 28, 128),
            nn.Tanh(),
            nn.Linear(128, 64),
            nn.Tanh(),
            nn.Linear(64, 12),
            nn.Tanh(),
            nn.Linear(12, 3),  # compress to 3 features for 3D visualisation
        )
        # Decoder: mirror of the encoder, 3 -> 12 -> 64 -> 128 -> 784.
        self.decoder = nn.Sequential(
            nn.Linear(3, 12),
            nn.Tanh(),
            nn.Linear(12, 64),
            nn.Tanh(),
            nn.Linear(64, 128),
            nn.Tanh(),
            nn.Linear(128, 28 * 28),
            nn.Sigmoid(),  # keep reconstructed pixel values in (0, 1)
        )
        # Classifier on the 3-D code. It deliberately ends in a Linear layer
        # and emits raw logits, because nn.CrossEntropyLoss applies
        # log-softmax itself. A trailing Sigmoid would squash every logit
        # into (0, 1), capping the softmax confidence at ~0.23 for 10
        # classes and shrinking the gradients.
        # The attribute keeps its original spelling so previously pickled
        # models and external references keep working.
        self.classfier = nn.Sequential(
            nn.Linear(3, 128),
            nn.Tanh(),
            nn.Linear(128, 10),
        )

    def forward(self, x):
        """Encode ``x``, then decode and classify the resulting code.

        Args:
            x: batch of flattened images; the first encoder layer requires
                the last dimension to be 28*28.

        Returns:
            Tuple ``(encoded, decoded, logits)``:
            - the 3-D code;
            - the reconstruction, in (0, 1), with last dimension 28*28;
            - unnormalised scores for the 10 classes.
        """
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        logits = self.classfier(encoded)
        return encoded, decoded, logits

def train():
    """Jointly train the autoencoder and its classifier head on MNIST.

    The model is optimised with Adam on the sum of two losses:
    - the pixel-wise MSE reconstruction loss;
    - the cross-entropy classification loss computed from the 3-D code.

    During training, the first ``N_TEST_IMG`` training images are drawn in
    the top row of a figure. After every optimisation step:
    - their reconstructions are redrawn in the bottom row;
    - their predicted digits are printed to stdout.

    After each epoch the whole model is pickled with ``torch.save`` to
    ``autoencoder_3<epoch>.pkl``.
    """
    # Hyper-parameters
    EPOCH = 20
    BATCH_SIZE = 128  # the batch size actually used by the DataLoader below
    LR = 0.005
    DOWNLOAD_MNIST = False  # set to True on the first run to fetch the data
    N_TEST_IMG = 5  # number of images shown to monitor reconstructions

    # MNIST digits dataset
    train_data = torchvision.datasets.MNIST(
        root='./mnist/',
        train=True,  # this is training data
        # ToTensor() yields float tensors of shape (C, H, W) scaled to [0.0, 1.0]
        transform=torchvision.transforms.ToTensor(),
        download=DOWNLOAD_MNIST,  # download it if you don't have it
    )
    train_loader = torch.utils.data.DataLoader(
        train_data, batch_size=BATCH_SIZE, shuffle=True
    )

    autoencoder = AutoEncodeNet()
    optimizer = torch.optim.Adam(autoencoder.parameters(), lr=LR)
    recon_loss_func = nn.MSELoss()  # reconstruction loss
    class_loss_func = nn.CrossEntropyLoss()  # classification loss (takes logits)

    _, ax = plt.subplots(2, N_TEST_IMG)
    plt.ion()  # interactive mode so the figure refreshes during training
    # Top row: the raw images used to monitor reconstruction quality.
    # train_data.data holds untransformed uint8 pixels, hence the /255.
    test_img = train_data.data[:N_TEST_IMG].view(-1, 28, 28).float() / 255.
    for i in range(N_TEST_IMG):
        ax[0][i].imshow(test_img[i].numpy())
    test_input = test_img.view(-1, 28 * 28)

    for epoch in range(EPOCH):
        for x, b_label in train_loader:
            b_x = x.view(-1, 28 * 28)  # batch input, shape (batch, 28*28)
            b_y = b_x  # the reconstruction target is the input itself

            _, decoded, logits = autoencoder(b_x)

            # Reconstruction (MSE) + classification (cross-entropy) loss.
            loss = recon_loss_func(decoded, b_y) + class_loss_func(logits, b_label)
            optimizer.zero_grad()  # clear gradients for this training step
            loss.backward()  # backpropagation, compute gradients
            optimizer.step()  # apply gradients

            # Preview pass only: no autograd graph is needed.
            with torch.no_grad():
                _, preview, preview_logits = autoencoder(test_input)
            preview = preview.view(-1, 28, 28)
            predictions = preview_logits.argmax(dim=1)
            for i in range(N_TEST_IMG):
                ax[1][i].clear()
                ax[1][i].imshow(preview[i].numpy())
                print(predictions[i].item(), end=" , ")
            print("------")

            plt.draw()
            plt.pause(0.01)

        print(f"epoch {epoch}: loss {loss.item():.4f}")
        torch.save(autoencoder, f"autoencoder_3{epoch}.pkl")
    # Leave interactive mode before show() so the final figure blocks and
    # stays open, instead of closing as soon as the function returns.
    plt.ioff()
    plt.show()
        
    
    
def test(model_path="autoencoder_8.pkl"):
    """Show one random MNIST training image next to its reconstruction.

    Loads a whole model pickled with ``torch.save`` (``train()`` writes files
    named ``autoencoder_3<epoch>.pkl``) and runs it on a randomly chosen
    training image. It then:
    - prints the input shape, the 3-D code, and the true vs. predicted digit;
    - plots the original and the reconstruction side by side.

    Args:
        model_path: path of the pickled model to load.
    """
    # MNIST digits dataset
    DOWNLOAD_MNIST = False
    train_data = torchvision.datasets.MNIST(
        root='./mnist/',
        train=True,  # this is training data
        # ToTensor() yields float tensors of shape (C, H, W) scaled to [0.0, 1.0]
        transform=torchvision.transforms.ToTensor(),
        download=DOWNLOAD_MNIST,  # download it if you don't have it
    )

    # SECURITY NOTE: torch.load unpickles the file, which can execute
    # arbitrary code -- only load checkpoints you created yourself.
    # NOTE(review): PyTorch >= 2.6 defaults to weights_only=True, which
    # rejects whole pickled modules; pass weights_only=False there.
    net = torch.load(model_path)
    net.eval()

    # randrange excludes the upper bound. The old randint(0, 60000) could
    # return 60000, one past the last valid index.
    index = random.randrange(len(train_data))
    img, label = train_data[index]  # fetch the sample once
    flat = img.view(-1, 28 * 28)
    print(flat.shape)

    with torch.no_grad():
        # forward() returns (encoded, decoded, logits); unpacking only two
        # values raised ValueError.
        en, de, logits = net(flat)
    print(en)
    print(f"label: {label}, predicted: {logits.argmax(dim=1).item()}")

    _, (ax1, ax2) = plt.subplots(1, 2)
    ax1.imshow(img.view(28, 28).numpy())
    ax2.imshow(de.view(28, 28).numpy())
    plt.show()
if __name__ == "__main__":
    # Script entry point: train a fresh model.
    # Call test() instead to inspect a previously saved checkpoint.
    train()

 

# Posted @ 2019-12-04 17:16 by dyigstraw (blog footer: 阅读(383) = 383 views, 评论(0) = 0 comments)