06_CycleGAN

cycle GAN主要用于图像之间的转换,如图像风格转换.

 

 

 Cycle GAN原理

  CycleGAN可以完成从一个模式到另外一个模式的转换,转换,比如从男人到女人:

 

 

 CycleGAN适用于非配对的图像到图像转换,CycleGAN解决了模型需要成对数据进行训练的困难。

 

 

 CycleGAN的论文地址 https://arxiv.org/pdf/1703.10593.pdf

 CycleGAN的原理可以概述为:将一类图片转换成另一类图片 。也就是说,现在有两个样本空间, X和Y,我们希望把X空间中的样本转换成Y空间中的样本。(获取一个数据集的特征,并转化成另一个数据集的特征)。

这样来看: 实际的目标就是学习从X到Y的映射。我们假设这个映射为F。它就对应着GAN中的 生成器 , F可以将X中的图片x转换为Y中的图片F(x)。对于生成的图片,我们还需要GAN中的 判别器 来判别它是否为真实图片,由此构成对抗生成网络。

Cycle GAN 模型及理论

 

 

 

   从理论上讲,对抗训练可以学习和产生与目标域Y相同分布的输出,但会产生一些问题。在足够大的样本容量下,网络可以将相同的输入图像集合映射到目标域中图像的任何随机排列,其中任何学习的映射可以归纳出与目标分布匹配的输出分布(即:映射F完全可以将所有x都映射为Y空间中的同一张图片,使损失无效化) 。因此,单独的对抗损失Loss不能保证学习函数可以将单个输入Xi映射到期望的输出Yi 。对此,作者又提出了所谓的“循环一致性损失” (cycle consistency loss).
  我们希望能够把 domain A 的图片(命名为 a)转化为 domain B 的图片(命名为图片 b)。为了实现这个过程,我们需要两个生成器 G_AB 和G_BA,分别把 domain A 和 domain B 的图片进行互相转换;将X的图片转换到Y空间后,应该还可以转换回来。这样就杜绝模型把所有X的图片都转换为Y空间中的同一张图片了。
  最后为了训练这个单向 GAN 需要两个 loss 分别是: 生成器的重建loss,判别器的判别loss。

生成器损失和判别损失
  判别 loss: 判别器 D_B 是用来判断输入的图片是否是真实的domain B图片。

 

 

   生成 loss:生成器用来重建图片 a,目的是希望生成的图片 G_BA(G_AB(a)) 和原图 a 尽可能的相似,那么可以很简单的采取 L1 loss 或者 L2 loss。最后生成 loss 就表示为:

 

 

   CycleGAN 其实就是一个 A→B 单向 GAN 加上一个B→A 单向 GAN。两个 GAN 共享两个生成器,然后各自带一个判别器,所以加起来总共有两个判别器和两个生成器。 一个单向 GAN 有两个 loss,而 CycleGAN 加起来总共有四个 loss。


循环一致性损失

 

 

   GAN网络的对抗loss之外,还有一个cycle-loss,也就是循环一致损失。因为网络需要保证生成的图像必须保留有原始图像的特性,所以如果我们使用生成器GenratorA-B生成一张假图像,那么要能够使用另外一个生成器GenratorB-A来努力恢复成原始图像。此过程必须满足循环一致性。

Cycle GAN损失

  循环一致损失:

 

 

   总的损失:

 

 

   在编写代码时还有个loss: identity loss,可以理解为,生成器是负责域x到域y的图像生成,如果输入域y的图片还是应该生成域y的图片y",计算y‘’ 和输入 y 的loss

Cycle GAN训练

  Generator采用的是Perceptual losses for real-time style transfer and super-resolution 一文中的网络结构 (论文地址:https://arxiv.org/abs/1603.08155)。

  一个resblock组成的网络,下采样部分采用stride 卷积,上采样部分采用反卷积
  Discriminator采用的仍是pix2pix中的PatchGAN结构,大小为70x70;

  定义四个损失函数,分别优化训练G和D,两个生成器共享权重,两个鉴别器也共享权重训练。

  Lr=0.0002。对于前100个周期,保持相同的学习速率0.0002,然后在接下来的100个周期内线性衰减到0。
Cycle GAN局限性

  对颜色、纹理等的转换效果比较好, 但是:

  1.)会在改变物体的同时改变背景。

 

   2.)缺少多样性

  Source domain 和 target domain 的维度应该是不一样的,比如笑和不笑,笑自然是闭着嘴,但是不笑的程度多种多样,可能有微笑,哈哈大笑等等。又比如无眼镜就是一种,有眼镜可以是各式各样的眼镜。

  3.)在需要几何变化的任务上表现一般。

 

 

Cycle GAN的代码实现(pytorch版本)

  1 import torch
  2 import torch.nn as nn
  3 import torch.nn.functional as F
  4 from torch.utils import data
  5 import torchvision
  6 from torchvision import transforms
  7 
  8 import numpy as np
  9 import matplotlib.pyplot as plt
 10 import os
 11 import glob
 12 from PIL import Image
 13 import itertools
 14 
 15 import time
 16 
 17 dataNames = os.listdir('apple2orange/')
 18 print(dataNames)
 19 
 20 def plotImg(data_path, name):
 21 
 22 
 23     plt.figure(figsize=(12, 8))
 24     for i, img_path in enumerate(data_path[:4]):
 25         img = Image.open(img_path)
 26         np_img = np.array(img)
 27         plt.subplot(2,2,i+1)
 28         plt.imshow(np_img)
 29         plt.title(str(np_img.shape))
 30 
 31     plt.show()
 32     plt.savefig('apple2orange/result/image_{:s}.png'.format(name))
 33     plt.close()
 34     
 35 
 36 apples_path = glob.glob('apple2orange/testA/*.jpg')
 37 oranges_path = glob.glob('apple2orange/testB/*.jpg')
 38 
 39 # print(len(apples_path), len(oranges_path))
 40 # print(apples_path[:3], oranges_path[:3])
 41 
 42 # plotImg(apples_path, "apple")
 43 # plotImg(oranges_path, "oranges")
 44 
 45 
 46 transform = transforms.Compose([
 47     transforms.ToTensor(),                        # 取值范围会被归一化到(0, 1)之间
 48     transforms.Normalize(mean=0.5, std=0.5)       # 设置均值和方差均为0.5
 49 ])
 50 
 51 class Apple2orange_Dataset(data.Dataset):
 52     def __init__(self, imgs_path):
 53         self.imgs_path = imgs_path
 54 
 55     def __getitem__(self, index):
 56         img_path = self.imgs_path[index]
 57         pil_img = Image.open(img_path)
 58         pil_img = transform(pil_img)
 59         return pil_img
 60     
 61     def __len__(self):
 62         return len(self.imgs_path)
 63 
 64 apples_dataset = Apple2orange_Dataset(apples_path)
 65 oranges_dataset = Apple2orange_Dataset(oranges_path)
 66 print(len(apples_dataset), len(oranges_dataset))    # 打印数据集大小 
 67 
 68 
 69 BTACH_SIZE = 32                                   # 批次大小
 70 apples_dl = torch.utils.data.DataLoader(
 71                                        apples_dataset,
 72                                        batch_size=BTACH_SIZE,
 73                                        shuffle=True
 74 )
 75 
 76 oranges_dl = torch.utils.data.DataLoader(
 77                                        oranges_dataset,
 78                                        batch_size=BTACH_SIZE,
 79                                        shuffle=True
 80 )
 81 
 82 apples_batch = next(iter(apples_dl))    # 返回一个批次的训练数据
 83 oranges_batch = next(iter(oranges_dl))    # 返回一个批次的训练数据
 84 
 85 # 绘制批次中前3对图片  
 86 def pltImgAndMask(imgData, maskData, imgName, maskName):
 87 
 88     # 绘制批次中前3对图片 
 89     plt.figure(figsize=(12, 18))
 90 
 91     for i, (img, mask) in enumerate(zip(imgData[:3], maskData[:3])):
 92         
 93         # 设置channel最后,并还原到取值0-1之间
 94         img = (img.permute(1, 2, 0).numpy() +1)/2
 95         mask = (mask.permute(1, 2, 0).numpy()+1)/2
 96         
 97         plt.subplot(3, 2, 2*i+1)
 98         plt.title('apple image')
 99         plt.imshow(img)
100         
101         plt.subplot(3, 2, 2*i+2)
102         plt.title('orange image')
103         plt.imshow(mask)
104 
105     plt.savefig('apple2orange/result/image_{:s}_{:s}.png'.format(imgName, maskName))
106     plt.close()
107 
108 imgName = "apple"
109 maskName = "orrange"
110 # pltImgAndMask(apples_batch, oranges_batch, imgName, maskName)
111 
112 
113 apples_path_test = glob.glob("apple2orange/testA/*.jpg")
114 oranges_path_test = glob.glob("apple2orange/testB/*.jpg")
115 
116 apples_dataset_test = Apple2orange_Dataset(apples_path_test)
117 oranges_dataset_test = Apple2orange_Dataset(oranges_path_test)
118 
119 apples_dl_test = torch.utils.data.DataLoader(
120                                        apples_dataset_test,
121                                        batch_size=BTACH_SIZE,
122                                        shuffle=True
123 )
124 
125 oranges_dl_test = torch.utils.data.DataLoader(
126                                        oranges_dataset_test,
127                                        batch_size=BTACH_SIZE,
128                                        shuffle=True
129 )
130 
131 class Downsample(nn.Module):
132     def __init__(self, in_channels, out_channels):
133         super(Downsample, self).__init__()
134         self.conv_relu = nn.Sequential(
135                             nn.Conv2d(in_channels, out_channels, 
136                                       kernel_size=3, stride=2, padding=1),
137                             nn.LeakyReLU(inplace=True),
138             )
139         self.bn = nn.InstanceNorm2d(out_channels)
140     def forward(self, x, is_bn=True):
141         x = self.conv_relu(x)
142         if is_bn:
143             x = self.bn(x)
144         return x
145     
146 class Upsample(nn.Module):
147     def __init__(self, in_channels, out_channels):
148         super(Upsample, self).__init__()
149         self.upconv_relu = nn.Sequential(
150                                nn.ConvTranspose2d(in_channels, out_channels, 
151                                                   kernel_size=3,
152                                                   stride=2,
153                                                   padding=1,
154                                                   output_padding=1),
155                                nn.LeakyReLU(inplace=True)
156             )
157         self.bn = nn.InstanceNorm2d(out_channels)
158         
159     def forward(self, x, is_drop=False):
160         x = self.upconv_relu(x)
161         x = self.bn(x)
162         if is_drop:
163             x = F.dropout2d(x)
164         return x
165     
166 class Generator(nn.Module):
167     def __init__(self):
168         super(Generator, self).__init__()
169         self.down1 = Downsample(3, 64)
170         self.down2 = Downsample(64, 128)
171         self.down3 = Downsample(128, 256)
172         self.down4 = Downsample(256, 512)
173         self.down5 = Downsample(512, 512)
174         self.down6 = Downsample(512, 512)
175         
176         self.up1 = Upsample(512, 512)
177         self.up2 = Upsample(1024, 512)
178         self.up3 = Upsample(1024, 256)
179         self.up4 = Upsample(512, 128)
180         self.up5 = Upsample(256, 64)
181 
182         self.last = nn.ConvTranspose2d(128, 3, 
183                                        kernel_size=3,
184                                        stride=2,
185                                        padding=1,
186                                        output_padding=1)
187         
188 
189     def forward(self, x):
190         x1 = self.down1(x, is_bn=False)   # torch.Size([8, 64, 128, 128])
191         x2 = self.down2(x1)               # torch.Size([8, 128, 64, 64])
192         x3 = self.down3(x2)               # torch.Size([8, 256, 32, 32])
193         x4 = self.down4(x3)               # torch.Size([8, 512, 16, 16])
194         x5 = self.down5(x4)               # torch.Size([8, 512, 8, 8])
195         x6 = self.down6(x5)               # torch.Size([8, 512, 4, 4])
196         
197         x6 = self.up1(x6, is_drop=True)   # torch.Size([8, 512, 8, 8])
198         x6 = torch.cat([x5, x6], dim=1)   # torch.Size([8, 1024, 8, 8])
199         
200         x6 = self.up2(x6, is_drop=True)   # torch.Size([8, 512, 16, 16])
201         x6 = torch.cat([x4, x6], dim=1)   # torch.Size([8, 1024, 16, 16])
202         
203         x6 = self.up3(x6, is_drop=True)                         
204         x6 = torch.cat([x3, x6], dim=1)          
205         
206         x6 = self.up4(x6)                       
207         x6 = torch.cat([x2, x6], dim=1)         
208         
209         x6 = self.up5(x6)                        
210         x6 = torch.cat([x1, x6], dim=1)         
211         
212         x6 = torch.tanh(self.last(x6))           
213         return x6
214 
215 # 定义判别器
216 class Discriminator(nn.Module):
217     def __init__(self):
218         super(Discriminator, self).__init__()
219         self.down1 = Downsample(3, 64)             # 128
220         self.down2 = Downsample(64, 128)           # 64
221         self.last = nn.Conv2d(128, 1, 3)
222 
223     def forward(self, img):
224         x = self.down1(img)
225         x = self.down2(x)
226         x = torch.sigmoid(self.last(x))
227         return x
228     
229 # 定义判别器
230 class Discriminator(nn.Module):
231     def __init__(self):
232         super(Discriminator, self).__init__()
233         self.down1 = Downsample(3, 64)             # 128
234         self.down2 = Downsample(64, 128)           # 64
235         self.last = nn.Conv2d(128, 1, 3)
236 
237     def forward(self, img):
238         x = self.down1(img)
239         x = self.down2(x)
240         x = torch.sigmoid(self.last(x))
241         return x
242     
243 device = "cuda:1" if torch.cuda.is_available() else "cpu"
244 gen_AB = Generator().to(device)
245 gen_BA = Generator().to(device)
246 dis_A = Discriminator().to(device)
247 dis_B = Discriminator().to(device)
248 
249 
250 bceloss_fn = torch.nn.BCELoss()                   # 定义损失函数
251 l1loss_fn = torch.nn.L1Loss()
252 
253 gen_optimizer = torch.optim.Adam(itertools.chain(gen_AB.parameters(), gen_BA.parameters()), 
254                                  lr=2e-4, betas=(0.5, 0.999))
255 dis_A_optimizer = torch.optim.Adam(dis_A.parameters(), lr=2e-4, betas=(0.5, 0.999))
256 dis_B_optimizer = torch.optim.Adam(dis_B.parameters(), lr=2e-4, betas=(0.5, 0.999)) 
257     
258 # 绘制测试结果图像
259 def generate_images(model, test_input, epoch):
260     prediction = model(test_input).permute(0, 2, 3, 1).detach().cpu().numpy()
261     test_input = test_input.permute(0, 2, 3, 1).cpu().numpy()
262     
263     plt.figure(figsize=(100, 50))
264     title = ['Input Image', 'Predicted Image']
265     for i in range(4):
266         plt.subplot(2, 4, i+1)
267         plt.title(title[0])
268         plt.imshow(test_input[i] * 0.5 + 0.5)
269         plt.axis('off')
270     for i in range(4):    
271         plt.subplot(2, 4, i+5)
272         plt.title(title[1])
273         plt.imshow(prediction[i] * 0.5 + 0.5)
274         plt.axis('off')
275     
276     plt.savefig('apple2orange/result/image_at_epoch_{:04d}.png'.format(epoch))
277     
278     
279 test_batch = next(iter(apples_dl_test))
280 # test_input = torch.unsqueeze(test_batch[0], 0).to(device)
281 test_input = test_batch[:4].to(device)
282 
283 
284 D_loss = []                          # 记录训练过程中判别器loss变化
285 G_loss = []                          # 记录训练过程中生成器loss变化
286 epochs = []
287 
288 #开始训练
289 for epoch in range(300):
290     
291     epoch_start = time.time()
292     
293     D_epoch_loss=0
294     G_epoch_loss=0
295     for step, (real_A, real_B) in enumerate(zip(apples_dl, oranges_dl)):
296         real_A = real_A.to(device)
297         real_B = real_B.to(device)
298         
299         # GAN 训练
300         gen_optimizer.zero_grad()
301         
302         # identity loss
303         same_B = gen_AB(real_B)
304         identity_B_loss = l1loss_fn(same_B, real_B)
305         same_A = gen_BA(real_A)
306         identity_A_loss = l1loss_fn(same_A, real_A)
307         
308         # GAN loss
309         fake_B = gen_AB(real_A)
310         D_pred_fake_B = dis_B(fake_B)
311         gan_loss_AB = bceloss_fn(D_pred_fake_B, 
312                                 torch.ones_like(D_pred_fake_B, device=device))
313         
314         fake_A = gen_BA(real_B)
315         D_pred_fake_A = dis_A(fake_A)
316         gan_loss_BA = bceloss_fn(D_pred_fake_A, 
317                                 torch.ones_like(D_pred_fake_A, device=device))
318         
319         # cycle consistanse loss
320         recovered_A = gen_BA(fake_B)
321         cycle_loss_ABA = l1loss_fn(recovered_A, real_A)
322         
323         recovered_B = gen_AB(fake_A)
324         cycle_loss_BAB = l1loss_fn(recovered_B, real_B)
325         
326         # total_loss
327         g_loss = (identity_B_loss + identity_A_loss + gan_loss_AB + gan_loss_BA
328                  + cycle_loss_ABA + cycle_loss_BAB)
329         
330         g_loss.backward()
331         gen_optimizer.step()
332         
333         # dis_A 训练
334         dis_A_optimizer.zero_grad()
335         dis_A_real_output = dis_A(real_A)              # 判别器输入真实图片
336         dis_A_real_loss = bceloss_fn(dis_A_real_output, 
337                                      torch.ones_like(dis_A_real_output, device=device))
338         
339         dis_A_fake_output = dis_A(fake_A.detach())              # 判别器输入生成图片
340         dis_A_fake_loss = bceloss_fn(dis_A_fake_output, 
341                                      torch.zeros_like(dis_A_fake_output, device=device))
342         
343         dis_A_loss = (dis_A_real_loss + dis_A_fake_loss)*0.5
344         
345         dis_A_loss.backward()
346         dis_A_optimizer.step()
347         
348         
349         # dis_B 训练
350         dis_B_optimizer.zero_grad()
351         dis_B_real_output = dis_B(real_B)              # 判别器输入真实图片
352         dis_B_real_loss = bceloss_fn(dis_B_real_output, 
353                                      torch.ones_like(dis_B_real_output, device=device))
354         
355         dis_B_fake_output = dis_B(fake_B.detach())              # 判别器输入生成图片
356         dis_B_fake_loss = bceloss_fn(dis_B_fake_output, 
357                                      torch.zeros_like(dis_B_fake_output, device=device))
358         
359         dis_B_loss = (dis_B_real_loss + dis_B_fake_loss)*0.5
360         
361         dis_B_loss.backward()
362         dis_B_optimizer.step()
363     
364         # 打印 loss 变化
365         with torch.no_grad():
366             D_epoch_loss += (dis_A_loss + dis_B_loss).item()
367             G_epoch_loss += g_loss.item()
368     
369     epoch_finish = time.time()
370      
371     with torch.no_grad():        
372         D_epoch_loss /= step
373         G_epoch_loss /= step
374         D_loss.append(D_epoch_loss)
375         G_loss.append(G_epoch_loss)
376         epochs.append(epoch)
377          
378         # 训练完一个Epoch,打印提示并绘制生成的图片
379         print("Epoch:", epoch, 
380                   'D_epoch_loss:{:.2f}'.format(D_epoch_loss),
381                   'G_epoch_loss:{:.2f}'.format(G_epoch_loss),
382                   'time:{:.2f}s'.format(epoch_finish-epoch_start))
383         
384         generate_images(gen_AB, test_input, epoch)  
385         
386         
387 
388 # 绘制loss函数
389 def D_G_loss_plot(D_loss, G_loss, epotchs):
390     
391     fig = plt.figure(figsize=(4, 4))
392     
393     plt.plot(epotchs, D_loss, label='D_loss')
394     plt.plot(epotchs, G_loss, label='G_loss')
395     plt.legend()
396     
397     plt.title("D_G_Loss")
398     plt.savefig('apple2orange/result/loss_at_epoch_{:04d}.png'.format(epotchs[len(epotchs)-1]))    
399     plt.close()
400 
401 D_G_loss_plot(D_loss, G_loss, epochs)
402 
403 torch.save(gen_AB, 'apple2orange/model/gen_AB_epoch_{:04d}.pt'.format(epochs[len(epochs)-1])) 
404 torch.save(gen_BA, 'apple2orange/model/gen_BA_epoch_{:04d}.pt'.format(epochs[len(epochs)-1])) 
View Code

 

Cycle GAN的训练效果

Epoch =1

 epoch=27

posted @ 2022-12-31 16:44  赵家小伙儿  阅读(143)  评论(0编辑  收藏  举报