Conditional AutoEncoder的PyTorch完全实现
一个完整的深度学习流程,必须包含的部分有:参数配置、Dataset和DataLoader构建、模型与optimizer与Loss函数创建、训练、验证、保存模型,以及读取模型、测试集验证等,对于生成模型来说,还应该有重构测试、生成测试。
AutoEncoder仅能够重构见过的数据,VAE可以通过采样生成新数据,对于MNIST数据集来说,两者都可以通过全连接神经网络训练。但如果我们需要用CNN来实现呢?也很容易。
Conditional VAE则有些特殊,它要把数据标签转换成One-Hot格式再拼接到数据上。MNIST数据集尚可,数据展平后也就784维;但对于一般的图像数据来说,这种直接拼接的做法就不可行了。
解决这个问题只需要CNN后面接MLP(多层感知机(就是全连接网络))就行了,对于CNN应有准确的认识,实际上,CNN从通道上看,和全连接神经网络一样,都是全连接的,如果把某一层的卷积核个数(通道数)视做全连接网络某一层的节点数,那么两者结构上是一样的,不同的是卷积核还要在图像上做复杂运算,每一通道的数据是2维的,全连接网络每一个“通道”是1个标量,因此通常CNN的数据是四个维度(bs, c, h, w),而全连接网络的数据一般两个维度(bs, dim)。
CNN的特点是,随着层数加深,图像尺寸越来越小,通道数越来越高,此时就和全连接神经网络非常像了,对每一个通道做一个全局平均池化(GAP)或者全局最大池化(Global Max Pooling),那不就变成全连接网络的输入了?
实现如下(Conditional AE,不是Conditional VAE)
编写Encoder如下:
import torch
import torch.nn as nn


class Encoder(nn.Module):
    """Convolutional encoder with an optional condition vector.

    Three strided convolutions reduce a 1-channel image to a 32-channel
    feature map, which is flattened and passed through a two-layer MLP.
    The first linear layer expects ``3 * 3 * 32`` image features, which is
    what the convolution stack produces for a 28x28 input.

    Args:
        encoded_space_dim: Size of the latent code returned by ``forward``.
        fc2_input_dim: Width of the hidden fully-connected layer.
        iscond: If True, a condition vector is concatenated to the flattened
            CNN features before the MLP.
        cond_dim: Length of the condition vector (only used when ``iscond``).
    """

    def __init__(self, encoded_space_dim, fc2_input_dim=128, iscond=False, cond_dim=10):
        super().__init__()
        self.encoder_cnn = nn.Sequential(
            nn.Conv2d(1, 8, 3, stride=2, padding=1),
            nn.ReLU(True),
            nn.Conv2d(8, 16, 3, stride=2, padding=1),
            nn.BatchNorm2d(16),
            nn.ReLU(True),
            nn.Conv2d(16, 32, 3, stride=2, padding=0),
            nn.ReLU(True),
        )
        self.flatten = nn.Flatten(start_dim=1)

        # The condition vector (if any) simply widens the first linear layer.
        lin_in_dim = 3 * 3 * 32 + (cond_dim if iscond else 0)
        self.encoder_lin = nn.Sequential(
            nn.Linear(lin_in_dim, fc2_input_dim),
            nn.ReLU(True),
            # The input width must follow fc2_input_dim; a literal 128 here
            # breaks every non-default hidden size.
            nn.Linear(fc2_input_dim, encoded_space_dim),
        )
        self.iscond = iscond

    def forward(self, x, cond_vec=None):
        """Encode ``x`` (and ``cond_vec`` when conditional) into a latent code.

        Args:
            x: Image batch fed to the convolution stack.
            cond_vec: Condition batch concatenated along dim 1; required when
                the encoder was built with ``iscond=True``, ignored otherwise.

        Returns:
            Tensor with ``encoded_space_dim`` features per sample.

        Raises:
            ValueError: If the encoder is conditional and ``cond_vec`` is None.
        """
        features = self.flatten(self.encoder_cnn(x))
        if self.iscond:
            if cond_vec is None:
                raise ValueError("cond_vec is required when iscond=True")
            features = torch.cat([features, cond_vec], dim=1)
        return self.encoder_lin(features)
Decoder用转置卷积,如下:
class Decoder(nn.Module):
    """Transposed-convolution decoder with an optional condition vector.

    A two-layer MLP expands the latent code (optionally concatenated with a
    condition vector) to ``3 * 3 * 32`` features, which are reshaped to a
    32x3x3 map and upsampled by three transposed convolutions. A sigmoid is
    applied to the final output.

    Args:
        encoded_space_dim: Size of the latent code.
        fc2_input_dim: Width of the hidden fully-connected layer.
        iscond: If True, a condition vector is concatenated to the latent code.
        cond_dim: Length of the condition vector (only used when ``iscond``).
    """

    def __init__(self, encoded_space_dim, fc2_input_dim=128, iscond=False, cond_dim=10):
        super().__init__()
        lin_in_dim = encoded_space_dim + (cond_dim if iscond else 0)
        self.decoder_lin = nn.Sequential(
            nn.Linear(lin_in_dim, fc2_input_dim),
            nn.ReLU(True),
            # The input width must follow fc2_input_dim; a literal 128 here
            # breaks every non-default hidden size.
            nn.Linear(fc2_input_dim, 3 * 3 * 32),
            nn.ReLU(True),
        )
        self.unflatten = nn.Unflatten(dim=1, unflattened_size=(32, 3, 3))
        self.decoder_conv = nn.Sequential(
            nn.ConvTranspose2d(32, 16, 3, stride=2, output_padding=0),
            nn.BatchNorm2d(16),
            nn.ReLU(True),
            nn.ConvTranspose2d(16, 8, 3, stride=2, padding=1, output_padding=1),
            nn.BatchNorm2d(8),
            nn.ReLU(True),
            nn.ConvTranspose2d(8, 1, 3, stride=2, padding=1, output_padding=1),
        )
        self.iscond = iscond

    def forward(self, x, cond_vec=None):
        """Decode a latent batch into images with values in [0, 1].

        Args:
            x: Latent codes with ``encoded_space_dim`` features per sample.
            cond_vec: Condition batch concatenated along dim 1; required when
                the decoder was built with ``iscond=True``, ignored otherwise.

        Returns:
            Sigmoid-activated output of the transposed-convolution stack.

        Raises:
            ValueError: If the decoder is conditional and ``cond_vec`` is None.
        """
        if self.iscond:
            if cond_vec is None:
                raise ValueError("cond_vec is required when iscond=True")
            x = torch.cat([x, cond_vec], dim=1)
        x = self.decoder_lin(x)
        x = self.unflatten(x)
        x = self.decoder_conv(x)
        return torch.sigmoid(x)
训练直接写一个Trainer,初始化、数据集、模型初始化部分如下:
class TrainerMNIST(object):
    """Trainer for the (conditional) convolutional autoencoder on MNIST.

    Holds the hyper-parameters, the run directory, the data loaders, the
    encoder/decoder pair, the optimizer and the loss function.

    Args:
        istrain: When True, the run directory is created and the training
            and validation loaders and the optimizer are built.
        data_dir: Root directory passed to ``torchvision.datasets.MNIST``.
    """

    def __init__(self, istrain=False, data_dir='F:/DATAS/mnist'):
        self.istrain = istrain
        self.data_dir = data_dir
        self.configs = {
            "lr": 0.001,
            "weight_decay": 1e-5,
            "batch_size": 256,
            "d": 4,
            "fc_input_dim": 128,
            "seed": 3407,
            "epochs": 12,
            "iscond": True,
            "cond_dim": 10,
        }
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.timestr = time.strftime("%Y%m%d%H%M", time.localtime())
        self.run_save_dir = "../run/aeconv_mnist/" + self.timestr + 'conditional_mix/'
        if istrain:
            os.makedirs(self.run_save_dir, exist_ok=True)
        self.model_e, self.model_d = None, None
        self.optim, self.loss_fn = None, None
        self.train_loader, self.valid_loader = None, None
        self.test_dataset, self.test_loader = None, None
        self.mix_rate = int(0.667 * self.configs["batch_size"])

    def __build_mnist(self, train):
        """Return the MNIST split with the image (and optional label) transform set."""
        dataset = torchvision.datasets.MNIST(self.data_dir, train=train, download=True)
        dataset.transform = transforms.Compose([transforms.ToTensor(), ])
        if self.configs["iscond"]:
            # onehot is a helper defined outside this snippet; presumably it
            # returns a callable that maps an int label to a one-hot vector
            # of length cond_dim. TODO confirm against its definition.
            dataset.target_transform = onehot(self.configs["cond_dim"])
        return dataset

    def __setup_dataset(self):
        """Build the datasets and loaders (train/valid only when ``istrain``)."""
        batch_size = self.configs["batch_size"]
        if self.istrain:
            train_dataset = self.__build_mnist(train=True)
            m = len(train_dataset)
            # Derive the train size from the validation size so the two
            # always sum to m; two independent int() truncations may not.
            n_val = int(m * 0.2)
            train_data, val_data = random_split(train_dataset, [m - n_val, n_val])
            # Reshuffle the training subset every epoch.
            self.train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size,
                                                            shuffle=True)
            self.valid_loader = torch.utils.data.DataLoader(val_data, batch_size=batch_size)
        self.test_dataset = self.__build_mnist(train=False)
        self.test_loader = torch.utils.data.DataLoader(self.test_dataset, batch_size=batch_size,
                                                       shuffle=True)
        print("Built Dataset and DataLoader")

    def setup_models(self):
        """Create encoder, decoder and loss; also the optimizer when training."""
        model_kwargs = dict(
            encoded_space_dim=self.configs["d"],
            fc2_input_dim=self.configs["fc_input_dim"],
            iscond=self.configs["iscond"],
            cond_dim=self.configs["cond_dim"],
        )
        self.model_e = Encoder(**model_kwargs)
        self.model_d = Decoder(**model_kwargs)
        self.model_e.to(self.device)
        self.model_d.to(self.device)
        self.loss_fn = torch.nn.MSELoss()
        if self.istrain:
            # A single Adam instance updates both networks jointly.
            paras_to_optimize = [
                {"params": self.model_e.parameters()},
                {"params": self.model_d.parameters()}
            ]
            self.optim = torch.optim.Adam(paras_to_optimize, lr=self.configs["lr"],
                                          weight_decay=self.configs["weight_decay"])
        print("Built Model and Optimizer and Loss Function")
然后编写训练方法和测试方法, 以及绘图:
# NOTE: the functions below are methods of TrainerMNIST (they take ``self``
# and use its configs, models and loaders); place them inside the class body.
def train_usl(self):
    """Train the autoencoder with a reconstruction loss only.

    Labels are used solely as the condition vector when ``configs["iscond"]``
    is set. After every epoch the validation loss is computed, both state
    dicts are saved to the run directory and a reconstruction figure is
    written. After the last epoch the loss curves are saved as
    ``LossIter.png``.
    """
    utils.setup_seed(self.configs["seed"])
    self.__setup_dataset()
    self.setup_models()
    diz_loss = {"train_loss": [], "val_loss": []}
    for epoch in range(self.configs["epochs"]):
        self.model_e.train()
        self.model_d.train()
        train_loss = []
        print(f"Train Epoch {epoch}")
        for i, (x, y) in enumerate(self.train_loader):
            x = x.to(self.device)
            decoded_data = self._reconstruct(x, y)
            loss_value = self.loss_fn(decoded_data, x)
            self.optim.zero_grad()
            loss_value.backward()
            self.optim.step()
            if i % 15 == 0:
                print(f"\t partial train loss (single batch: {loss_value.item():.6f})")
            train_loss.append(loss_value.item())
        train_loss_value = np.mean(train_loss)

        self.model_e.eval()
        self.model_d.eval()
        with torch.no_grad():
            conc_out = []
            conc_label = []
            for x, y in self.valid_loader:
                x = x.to(self.device)
                conc_out.append(self._reconstruct(x, y).cpu())
                conc_label.append(x.cpu())
            # One loss over the whole validation set, not a mean of batch means.
            val_loss = self.loss_fn(torch.cat(conc_out), torch.cat(conc_label))
        # Store a plain float so the history can be plotted directly.
        val_loss_value = val_loss.item()
        # The message says "test loss" but this is the validation split.
        print(f"\t Epoch {epoch} test loss: {val_loss_value}")
        diz_loss["train_loss"].append(train_loss_value)
        diz_loss["val_loss"].append(val_loss_value)
        torch.save(self.model_e.state_dict(),
                   self.run_save_dir + '{}_epoch_{}.pth'.format("aeconve_cond", epoch))
        torch.save(self.model_d.state_dict(),
                   self.run_save_dir + '{}_epoch_{}.pth'.format("aeconvd_cond", epoch))
        self.plot_ae_outputs(epoch_id=epoch)

    fig = plt.figure(figsize=(10, 8))
    plt.semilogy(diz_loss["train_loss"], label="Train")
    plt.semilogy(diz_loss["val_loss"], label="Valid")
    plt.xlabel("Epoch")
    plt.ylabel("Average Loss")
    plt.legend()
    plt.savefig(self.run_save_dir + "LossIter.png", format="png", dpi=300)
    plt.close(fig)


def plot_ae_outputs(self, epoch_id):
    """Save a 2x5 figure: the first five test images above their reconstructions.

    Args:
        epoch_id: Epoch index used in the output file name.
    """
    n_images = 5
    self.model_e.eval()
    self.model_d.eval()
    fig = plt.figure()
    for i in range(n_images):
        img, label = self.test_dataset[i]
        img = img.unsqueeze(0).to(self.device)
        # Only the conditional setup has a tensor label to add a batch dim to.
        y = label.unsqueeze(0) if self.configs["iscond"] else None
        with torch.no_grad():
            rec_img = self._reconstruct(img, y)

        ax = plt.subplot(2, n_images, i + 1)
        plt.imshow(img.cpu().squeeze().numpy(), cmap="gist_gray")
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
        if i == n_images // 2:
            ax.set_title("Original images")

        ax = plt.subplot(2, n_images, i + 1 + n_images)
        plt.imshow(rec_img.cpu().squeeze().numpy(), cmap="gist_gray")
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
        if i == n_images // 2:
            ax.set_title("Reconstructed images")
    plt.savefig(self.run_save_dir + f"test_plot_epoch_{epoch_id}.png", format="png", dpi=300)
    # Close the figure; this is called once per epoch and open figures accumulate.
    plt.close(fig)


def _reconstruct(self, x, y=None):
    """Run encoder then decoder on ``x``, passing ``y`` as condition if enabled.

    ``x`` must already be on ``self.device``; ``y`` is moved there when used.
    """
    if self.configs["iscond"]:
        y = y.to(self.device)
        return self.model_d(self.model_e(x, y), y)
    return self.model_d(self.model_e(x))
main函数里运行如下:
if __name__ == '__main__':
    # Training entry point: builds the run directory, then trains and saves
    # checkpoints, reconstruction figures and the loss curve.
    trainer = TrainerMNIST(istrain=True)
    trainer.train_usl()
训练完毕后,在运行保存的目录里面就生成了模型文件、重构图像、损失函数迭代曲线图像。
再编写一个测试生成图像的函数:
# NOTE: recon_test_one is a method of TrainerMNIST (it takes ``self`` and uses
# its configs and models); place it inside the class body.
def recon_test_one(self, resume_path, epoch=11):
    """Generate one image per class from random latent codes and plot them.

    Loads the encoder/decoder checkpoints of the given epoch from
    ``resume_path``, samples one standard-normal latent code per class,
    decodes each with the matching one-hot label, and saves the grid as
    ``generate_0.png`` in ``resume_path`` before showing it.

    Args:
        resume_path: Directory that contains the saved ``.pth`` files.
        epoch: Epoch index in the checkpoint file names (default 11, the
            last epoch of a 12-epoch run).
    """
    if (self.model_e is None) or (self.model_d is None):
        self.setup_models()
    # map_location lets checkpoints saved on GPU be loaded on a CPU-only machine.
    state_dict_e = torch.load(os.path.join(resume_path, f'aeconve_cond_epoch_{epoch}.pth'),
                              map_location=self.device)
    self.model_e.load_state_dict(state_dict_e)
    state_dict_d = torch.load(os.path.join(resume_path, f'aeconvd_cond_epoch_{epoch}.pth'),
                              map_location=self.device)
    self.model_d.load_state_dict(state_dict_d)
    self.model_e.eval()
    self.model_d.eval()
    print(self.model_d)

    n_classes = self.configs["cond_dim"]
    z = torch.randn(size=(n_classes, self.configs["d"]), device=self.device)
    # Row i of y_label is the one-hot vector for class i.
    y_label = torch.zeros(size=(n_classes, n_classes))
    labels = torch.arange(0, n_classes).unsqueeze(1)
    y_label.scatter_(1, labels, 1)
    y_label = y_label.to(self.device)

    with torch.no_grad():
        recon_images = self.model_d(z, y_label)
    # Drop only the channel dim so a single-class setup keeps its batch dim.
    recon_images = recon_images.squeeze(1).cpu().numpy()

    n_cols = 5
    n_rows = -(-n_classes // n_cols)  # ceiling division
    plt.figure(0)
    for i in range(n_classes):
        plt.subplot(n_rows, n_cols, i + 1)
        plt.imshow(recon_images[i])
        plt.xticks([])
        plt.yticks([])
        plt.title(f"generate_{i}")
    plt.savefig(os.path.join(resume_path, "generate_0.png"), format="png", dpi=300)
    plt.show()


if __name__ == '__main__':
    # istrain=False: generation only loads checkpoints. With istrain=True the
    # constructor would also create a new, empty timestamped run directory.
    # To retrain instead, build TrainerMNIST(istrain=True) and call train_usl().
    trainer = TrainerMNIST(istrain=False)
    trainer.recon_test_one("../run/aeconv_mnist/202404141930conditional_mix/")
测试结果发现其实效果并不好,个人觉得这是Conditional AE自身的问题,学习到的表征并不充分,没有涵盖生成新数据的连续空间,如果是Conditional VAE,就能做到。