Conditional AutoEncoder的PyTorch完整实现

一个完整的深度学习流程,必须包含的部分有:参数配置、Dataset和DataLoader构建、模型与optimizer与Loss函数创建、训练、验证、保存模型,以及读取模型、测试集验证等,对于生成模型来说,还应该有重构测试、生成测试。

AutoEncoder仅能够重构见过的数据,VAE则可以通过采样生成新数据,对于MNIST数据集来说,二者都可以用全连接神经网络训练。如果要用CNN来实现,也很容易。

Conditional VAE则有些特殊,它要把数据标签转换成One-Hot格式再拼接到数据上。对MNIST数据集尚可,图像展平后也只有784维;但对于一般(尺寸更大、通道更多)的图像数据,把标签直接拼接到展平的原始图像上就不可行了。

解决这个问题只需要在CNN后面接MLP(多层感知机,即全连接网络)就行了。对CNN应有准确的认识:实际上,从通道上看,CNN和全连接神经网络一样,都是全连接的。如果把某一层的卷积核个数(通道数)视作全连接网络某一层的节点数,那么两者在结构上是一样的。不同的是,卷积核还要在图像上做卷积运算,每一通道的数据是2维的,而全连接网络每一个“通道”是1个标量。因此CNN的数据通常是四个维度(bs, c, h, w),而全连接网络的数据一般是两个维度(bs, dim)。

CNN的特点是,随着层数加深,图像尺寸越来越小,通道数越来越高,此时就和全连接神经网络非常像了。对每一个通道做一个全局平均池化(GAP)或者最大池化(MaxPooling),不就变成全连接网络的输入了吗?(下面的实现没有用池化,而是把最后3×3×32的特征图直接展平(Flatten)成288维向量再送入全连接层,思路是一样的。)

实现如下(Conditional AE,不是Conditional VAE):

编写Encoder如下:

import torch
import torch.nn as nn


class Encoder(nn.Module):
    """CNN + MLP encoder mapping a single-channel image to a latent code.

    The convolutional stack turns a 1x28x28 image (MNIST) into a 32x3x3
    feature map, which is flattened (3 * 3 * 32 = 288 values) and fed to a
    two-layer MLP. In conditional mode a condition vector (e.g. a one-hot
    label) is concatenated to the flattened features before the MLP.

    Args:
        encoded_space_dim: size of the latent code returned by ``forward``.
        fc2_input_dim: width of the hidden fully-connected layer.
        iscond: if True, ``forward`` requires ``cond_vec``.
        cond_dim: length of the condition vector (used only when ``iscond``).
    """

    def __init__(self, encoded_space_dim, fc2_input_dim=128, iscond=False, cond_dim=10):
        super().__init__()
        self.encoder_cnn = nn.Sequential(
            nn.Conv2d(1, 8, 3, stride=2, padding=1),    # 28x28 -> 14x14
            nn.ReLU(True),
            nn.Conv2d(8, 16, 3, stride=2, padding=1),   # 14x14 -> 7x7
            nn.BatchNorm2d(16),
            nn.ReLU(True),
            nn.Conv2d(16, 32, 3, stride=2, padding=0),  # 7x7 -> 3x3
            nn.ReLU(True)
        )

        self.flatten = nn.Flatten(start_dim=1)
        # The condition vector is appended to the flattened CNN features, so the
        # MLP input grows by cond_dim in conditional mode. The Sequential layout
        # (indices 0 and 2 hold the Linear layers) is the same in both modes,
        # which keeps state_dict keys compatible with existing checkpoints.
        in_features = 3 * 3 * 32 + (cond_dim if iscond else 0)
        self.encoder_lin = nn.Sequential(
            nn.Linear(in_features, fc2_input_dim),
            nn.ReLU(True),
            # Must consume the hidden width: previously hard-coded to 128, which
            # broke every fc2_input_dim other than 128.
            nn.Linear(fc2_input_dim, encoded_space_dim)
        )
        self.iscond = iscond

    def forward(self, x, cond_vec=None):
        """Encode ``x`` (batch, 1, 28, 28) into a (batch, encoded_space_dim) code.

        Args:
            x: input images.
            cond_vec: (batch, cond_dim) condition vectors; required when the
                encoder was built with ``iscond=True``, ignored otherwise.

        Raises:
            ValueError: if the encoder is conditional and ``cond_vec`` is None.
        """
        x = self.flatten(self.encoder_cnn(x))
        if self.iscond:
            if cond_vec is None:
                raise ValueError("cond_vec is required when the encoder is built with iscond=True")
            # Align dtype so integer one-hot labels can be concatenated to float features.
            x = torch.cat([x, cond_vec.to(dtype=x.dtype)], dim=1)
        return self.encoder_lin(x)

Decoder用转置卷积,如下:

class Decoder(nn.Module):
    """MLP + transposed-CNN decoder mapping a latent code back to a 1x28x28 image.

    The MLP expands the code (optionally concatenated with a condition vector)
    to 3 * 3 * 32 values, which are reshaped to a 32x3x3 feature map and
    upsampled by transposed convolutions to 28x28. A sigmoid maps the output
    to (0, 1).

    Args:
        encoded_space_dim: size of the latent code accepted by ``forward``.
        fc2_input_dim: width of the hidden fully-connected layer.
        iscond: if True, ``forward`` requires ``cond_vec``.
        cond_dim: length of the condition vector (used only when ``iscond``).
    """

    def __init__(self, encoded_space_dim, fc2_input_dim=128, iscond=False, cond_dim=10):
        super().__init__()
        # The condition vector is appended to the latent code in conditional mode.
        # The Sequential layout is identical in both modes, so state_dict keys
        # stay compatible with existing checkpoints.
        in_features = encoded_space_dim + (cond_dim if iscond else 0)
        self.decoder_lin = nn.Sequential(
            nn.Linear(in_features, fc2_input_dim),
            nn.ReLU(True),
            # Must consume the hidden width: previously hard-coded to 128, which
            # broke every fc2_input_dim other than 128.
            nn.Linear(fc2_input_dim, 3 * 3 * 32),
            nn.ReLU(True)
        )
        self.unflatten = nn.Unflatten(dim=1,
                                      unflattened_size=(32, 3, 3))
        self.decoder_conv = nn.Sequential(
            nn.ConvTranspose2d(32, 16, 3,
                               stride=2, output_padding=0),   # 3x3 -> 7x7
            nn.BatchNorm2d(16),
            nn.ReLU(True),
            nn.ConvTranspose2d(16, 8, 3, stride=2,
                               padding=1, output_padding=1),  # 7x7 -> 14x14
            nn.BatchNorm2d(8),
            nn.ReLU(True),
            nn.ConvTranspose2d(8, 1, 3, stride=2,
                               padding=1, output_padding=1)   # 14x14 -> 28x28
        )
        self.iscond = iscond

    def forward(self, x, cond_vec=None):
        """Decode (batch, encoded_space_dim) codes into (batch, 1, 28, 28) images in (0, 1).

        Args:
            x: latent codes.
            cond_vec: (batch, cond_dim) condition vectors; required when the
                decoder was built with ``iscond=True``, ignored otherwise.

        Raises:
            ValueError: if the decoder is conditional and ``cond_vec`` is None.
        """
        if self.iscond:
            if cond_vec is None:
                raise ValueError("cond_vec is required when the decoder is built with iscond=True")
            # Align dtype so integer one-hot labels can be concatenated to float codes.
            x = torch.cat([x, cond_vec.to(dtype=x.dtype)], dim=1)
        x = self.decoder_lin(x)
        x = self.unflatten(x)
        x = self.decoder_conv(x)
        return torch.sigmoid(x)

训练直接写一个Trainer,初始化、数据集、模型初始化部分如下:

class TrainerMNIST(object):
    """Sets up data, models and optimizer for the (conditional) convolutional AutoEncoder on MNIST.

    With ``istrain=True`` a timestamped run directory is created, and the
    train/validation split and the Adam optimizer are built. With
    ``istrain=False`` only the test set and the models are prepared, which is
    enough for evaluating a saved checkpoint.
    """

    def __init__(self, istrain=False):
        self.istrain = istrain
        # "d": latent size, "fc_input_dim": hidden MLP width, "iscond": use label
        # conditioning, "cond_dim": length of the condition vector.
        self.configs = {
            "lr": 0.001, "weight_decay": 1e-5, "batch_size": 256, "d": 4, "fc_input_dim": 128, "seed": 3407,
            "epochs": 12, "iscond": True, "cond_dim": 10,
        }
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.timestr = time.strftime("%Y%m%d%H%M", time.localtime())
        self.run_save_dir = "../run/aeconv_mnist/" + self.timestr + 'conditional_mix/'
        if istrain:
            # Only training writes artifacts (checkpoints, plots), so only it creates the directory.
            os.makedirs(self.run_save_dir, exist_ok=True)
        self.model_e, self.model_d = None, None
        self.optim, self.loss_fn = None, None
        self.train_loader, self.valid_loader = None, None
        self.test_dataset, self.test_loader = None, None
        # NOTE(review): not read by any method shown here — confirm it is unused before removing.
        self.mix_rate = int(0.667 * self.configs["batch_size"])

    @staticmethod
    def _split_sizes(m, val_frac=0.2):
        """Return ``(n_train, n_val)`` for a dataset of ``m`` samples.

        ``random_split`` requires lengths that sum exactly to the dataset size.
        Truncating both parts independently with ``int()`` can drop a sample
        (e.g. m=7 gives 5 + 1). Deriving the train size from the validation
        size avoids that.
        """
        n_val = int(m * val_frac)
        return m - n_val, n_val

    def __setup_dataset(self):
        """Build the MNIST datasets and DataLoaders (train/valid only when ``istrain``)."""
        data_dir = 'F:/DATAS/mnist'
        if self.istrain:
            train_dataset = torchvision.datasets.MNIST(data_dir, train=True, download=True)
            train_dataset.transform = transforms.Compose([transforms.ToTensor(), ])
            if self.configs["iscond"]:
                # `onehot` is defined elsewhere; presumably it turns an int label into a
                # cond_dim-long one-hot vector — confirm against its definition.
                train_dataset.target_transform = onehot(self.configs["cond_dim"])
            n_train, n_val = self._split_sizes(len(train_dataset))
            train_data, val_data = random_split(train_dataset, [n_train, n_val])
            # Reshuffle training batches every epoch; validation order is irrelevant
            # because its loss is computed over the whole split.
            self.train_loader = torch.utils.data.DataLoader(
                train_data, batch_size=self.configs["batch_size"], shuffle=True)
            self.valid_loader = torch.utils.data.DataLoader(val_data, batch_size=self.configs["batch_size"])

        self.test_dataset = torchvision.datasets.MNIST(data_dir, train=False, download=True)
        self.test_dataset.transform = transforms.Compose([transforms.ToTensor(), ])
        if self.configs["iscond"]:
            self.test_dataset.target_transform = onehot(self.configs["cond_dim"])
        self.test_loader = torch.utils.data.DataLoader(
            self.test_dataset, batch_size=self.configs["batch_size"], shuffle=True)

        print("Built Dataset and DataLoader")

    def setup_models(self):
        """Create encoder/decoder on ``self.device``, the MSE loss and (when training) Adam."""
        self.model_e = Encoder(encoded_space_dim=self.configs["d"], fc2_input_dim=self.configs["fc_input_dim"],
                               iscond=self.configs["iscond"], cond_dim=self.configs["cond_dim"])
        self.model_d = Decoder(encoded_space_dim=self.configs["d"], fc2_input_dim=self.configs["fc_input_dim"],
                               iscond=self.configs["iscond"], cond_dim=self.configs["cond_dim"])
        self.model_e.to(self.device)
        self.model_d.to(self.device)
        self.loss_fn = torch.nn.MSELoss()
        if self.istrain:
            # One optimizer over both halves so the AutoEncoder is trained end to end.
            paras_to_optimize = [
                {"params": self.model_e.parameters()},
                {"params": self.model_d.parameters()}
            ]
            self.optim = torch.optim.Adam(paras_to_optimize, lr=self.configs["lr"],
                                          weight_decay=self.configs["weight_decay"])
        print("Built Model and Optimizer and Loss Function")

然后编写训练方法和测试方法, 以及绘图:

    def train_usl(self):
        """Unsupervised training of the (conditional) AutoEncoder.

        Builds data and models, then for each epoch:
        - trains on ``train_loader``;
        - computes the reconstruction MSE over the whole validation split;
        - saves encoder/decoder checkpoints and a reconstruction plot into ``run_save_dir``.

        Finally it saves the train/validation loss curve as ``LossIter.png``.
        """
        # Use the configured seed instead of repeating the literal 3407.
        utils.setup_seed(self.configs["seed"])
        self.__setup_dataset()
        self.setup_models()
        iscond = self.configs["iscond"]

        def reconstruct(x, y):
            # Encoder -> decoder; the label/condition is only passed in conditional mode.
            if iscond:
                y = y.to(self.device)
                return self.model_d(self.model_e(x, y), y)
            return self.model_d(self.model_e(x))

        diz_loss = {"train_loss": [], "val_loss": []}
        for epoch in range(self.configs["epochs"]):
            self.model_e.train()
            self.model_d.train()
            train_loss = []
            print(f"Train Epoch {epoch}")
            for i, (x, y) in enumerate(self.train_loader):
                x = x.to(self.device)
                loss_value = self.loss_fn(x, reconstruct(x, y))

                self.optim.zero_grad()
                loss_value.backward()
                self.optim.step()
                if i % 15 == 0:
                    print(f"\t partial train loss (single batch: {loss_value.item():.6f})")
                train_loss.append(loss_value.item())
            train_loss_value = np.mean(train_loss)

            self.model_e.eval()
            self.model_d.eval()
            with torch.no_grad():
                conc_out, conc_label = [], []
                for x, y in self.valid_loader:
                    x = x.to(self.device)
                    conc_out.append(reconstruct(x, y).cpu())
                    conc_label.append(x.cpu())
                # MSE over the whole validation split (not an average of per-batch means).
                val_loss_value = self.loss_fn(torch.cat(conc_out), torch.cat(conc_label)).item()
            print(f"\t Epoch {epoch} test loss: {val_loss_value}")
            diz_loss["train_loss"].append(train_loss_value)
            diz_loss["val_loss"].append(val_loss_value)
            torch.save(self.model_e.state_dict(), self.run_save_dir + '{}_epoch_{}.pth'.format("aeconve_cond", epoch))
            torch.save(self.model_d.state_dict(), self.run_save_dir + '{}_epoch_{}.pth'.format("aeconvd_cond", epoch))
            self.plot_ae_outputs(epoch_id=epoch)

        fig = plt.figure(figsize=(10, 8))
        plt.semilogy(diz_loss["train_loss"], label="Train")
        plt.semilogy(diz_loss["val_loss"], label="Valid")
        plt.xlabel("Epoch")
        plt.ylabel("Average Loss")
        plt.legend()
        plt.savefig(self.run_save_dir + "LossIter.png", format="png", dpi=300)
        # Release the figure; it was previously left open after saving.
        plt.close(fig)

    def plot_ae_outputs(self, epoch_id, n=5):
        """Save a grid comparing test images with their reconstructions.

        The first ``n`` images of ``self.test_dataset`` go in the top row and
        their reconstructions in the bottom row. The figure is written to
        ``run_save_dir + "test_plot_epoch_{epoch_id}.png"`` and then closed.

        Args:
            epoch_id: epoch number used in the output file name.
            n: number of test images to show (default 5, as before).
        """
        # Inference mode for both halves once, instead of once per image.
        self.model_e.eval()
        self.model_d.eval()
        fig = plt.figure()
        for i in range(n):
            # Fetch each sample once; indexing the dataset re-runs its transforms.
            img, y = self.test_dataset[i]
            img = img.unsqueeze(0).to(self.device)
            with torch.no_grad():
                if self.configs["iscond"]:
                    y = y.unsqueeze(0).to(self.device)
                    rec_img = self.model_d(self.model_e(img, y), y)
                else:
                    rec_img = self.model_d(self.model_e(img))

            ax = plt.subplot(2, n, i + 1)
            plt.imshow(img.cpu().squeeze().numpy(), cmap="gist_gray")
            ax.get_xaxis().set_visible(False)
            ax.get_yaxis().set_visible(False)
            if i == n // 2:
                ax.set_title("Original images")

            ax = plt.subplot(2, n, i + 1 + n)
            plt.imshow(rec_img.cpu().squeeze().numpy(), cmap="gist_gray")
            ax.get_xaxis().set_visible(False)
            ax.get_yaxis().set_visible(False)
            if i == n // 2:
                ax.set_title("Reconstructed images")
        plt.savefig(self.run_save_dir + f"test_plot_epoch_{epoch_id}.png", format="png", dpi=300)
        # Called every epoch: without closing, open figures pile up in memory.
        plt.close(fig)

main函数里运行如下:

if __name__ == '__main__':
    # Train the model. Checkpoints, per-epoch reconstruction plots and the loss
    # curve are written to trainer.run_save_dir. Previously the training call
    # was commented out, so the script only created an empty run directory.
    trainer = TrainerMNIST(istrain=True)
    trainer.train_usl()

训练完毕后,在运行保存的目录里面就生成了模型文件、重构图像、损失函数迭代曲线图像。

再编写一个测试生成图像的函数:

    def recon_test_one(self, resume_path, epoch=None):
        """Generate one image per class from random latent codes with a saved model.

        Loads ``aeconve_cond_epoch_{epoch}.pth`` and ``aeconvd_cond_epoch_{epoch}.pth``
        from ``resume_path``. It then decodes ``configs["cond_dim"]`` random
        standard-normal latent vectors, the i-th paired with the one-hot label i.
        The image grid is saved as ``generate_0.png`` in ``resume_path`` and shown.

        Args:
            resume_path: directory containing the checkpoints.
            epoch: checkpoint epoch to load; defaults to the last trained epoch
                (``configs["epochs"] - 1``, i.e. 11 with the default config).
        """
        if (self.model_e is None) or (self.model_d is None):
            self.setup_models()
        if epoch is None:
            epoch = self.configs["epochs"] - 1
        # map_location lets checkpoints saved on a GPU load on a CPU-only machine.
        state_dict_e = torch.load(os.path.join(resume_path, f'aeconve_cond_epoch_{epoch}.pth'),
                                  map_location=self.device)
        self.model_e.load_state_dict(state_dict_e)
        state_dict_d = torch.load(os.path.join(resume_path, f'aeconvd_cond_epoch_{epoch}.pth'),
                                  map_location=self.device)
        self.model_d.load_state_dict(state_dict_d)
        self.model_e.eval()
        self.model_d.eval()
        print(self.model_d)

        n = self.configs["cond_dim"]
        z = torch.randn(size=(n, self.configs["d"]), device=self.device)
        # Row i is the one-hot vector for label i.
        y_label = torch.eye(n, device=self.device)
        with torch.no_grad():
            recon_images = self.model_d(z, y_label)
        # Drop the single channel axis: (n, 1, H, W) -> (n, H, W).
        recon_images = recon_images[:, 0].cpu().numpy()

        ncols = (n + 1) // 2
        plt.figure(0)
        for i in range(n):
            plt.subplot(2, ncols, i + 1)
            plt.imshow(recon_images[i])
            plt.xticks([])
            plt.yticks([])
            plt.title(f"generate_{i}")
        plt.savefig(os.path.join(resume_path, "generate_0.png"), format="png", dpi=300)
        plt.show()


if __name__ == '__main__':
    # Generation only. istrain=False avoids creating a new, empty timestamped run
    # directory and skips building an optimizer that inference never uses.
    # To train first, construct TrainerMNIST(istrain=True) and call train_usl().
    trainer = TrainerMNIST(istrain=False)
    trainer.recon_test_one("../run/aeconv_mnist/202404141930conditional_mix/")

测试结果发现效果其实并不好。个人认为这是Conditional AE自身的问题:它学习到的表征并不充分,潜在空间没有形成可供生成新数据的连续区域,因此从标准正态分布随机采样的z(即代码中的torch.randn)不一定落在解码器训练时见过的区域。如果换成Conditional VAE,训练时会约束潜在变量的分布,就能做到这一点。

posted @ 2024-04-14 20:17  倦鸟已归时  阅读(65)  评论(0编辑  收藏  举报