06_CycleGAN
cycle GAN主要用于图像之间的转换,如图像风格转换.
Cycle GAN原理
CycleGAN可以完成从一个模式到另外一个模式的转换,转换,比如从男人到女人:
CycleGAN适用于非配对的图像到图像转换,CycleGAN解决了模型需要成对数据进行训练的困难。
CycleGAN的论文地址 :https://arxiv.org/pdf/1703.10593.pdf
CycleGAN的原理可以概述为:将一类图片转换成另一类图片 。也就是说,现在有两个样本空间, X和Y,我们希望把X空间中的样本转换成Y空间中的样本。(获取一个数据集的特征,并转化成另一个数据集的特征)。
这样来看: 实际的目标就是学习从X到Y的映射。我们假设这个映射为F。它就对应着GAN中的 生成器 , F可以将X中的图片x转换为Y中的图片F(x)。对于生成的图片,我们还需要GAN中的 判别器 来判别它是否为真实图片,由此构成对抗生成网络。
Cycle GAN 模型及理论
从理论上讲,对抗训练可以学习和产生与目标域Y相同分布的输出,但会产生一些问题。在足够大的样本容量下,网络可以将相同的输入图像集合映射到目标域中图像的任何随机排列,其中任何学习的映射可以归纳出与目标分布匹配的输出分布(即:映射F完全可以将所有x都映射为Y空间中的同一张图片,使损失无效化) 。因此,单独的对抗损失Loss不能保证学习函数可以将单个输入Xi映射到期望的输出Yi 。对此,作者又提出了所谓的“循环一致性损失” (cycle consistency loss).
我们希望能够把 domain A 的图片(命名为 a)转化为 domain B 的图片(命名为图片 b)。为了实现这个过程,我们需要两个生成器 G_AB 和G_BA,分别把 domain A 和 domain B 的图片进行互相转换;将X的图片转换到Y空间后,应该还可以转换回来。这样就杜绝模型把所有X的图片都转换为Y空间中的同一张图片了。
最后为了训练这个单向 GAN 需要两个 loss 分别是: 生成器的重建loss,判别器的判别loss。
生成器损失和判别损失
判别 loss: 判别器 D_B 是用来判断输入的图片是否是真实的domain B图片。
生成 loss:生成器用来重建图片 a,目的是希望生成的图片 G_BA(G_AB(a)) 和原图 a 尽可能的相似,那么可以很简单的采取 L1 loss 或者 L2 loss。最后生成 loss 就表示为:
CycleGAN 其实就是一个 A→B 单向 GAN 加上一个B→A 单向 GAN。两个 GAN 共享两个生成器,然后各自带一个判别器,所以加起来总共有两个判别器和两个生成器。 一个单向 GAN 有两个 loss,而 CycleGAN 加起来总共有四个 loss。
循环一致性损失
GAN网络的对抗loss之外,还有一个cycle-loss,也就是循环一致损失。因为网络需要保证生成的图像必须保留有原始图像的特性,所以如果我们使用生成器GenratorA-B生成一张假图像,那么要能够使用另外一个生成器GenratorB-A来努力恢复成原始图像。此过程必须满足循环一致性。
Cycle GAN损失
循环一致损失:
总的损失:
在编写代码时还有个loss: identity loss,可以理解为,生成器是负责域x到域y的图像生成,如果输入域y的图片还是应该生成域y的图片y",计算y‘’ 和输入 y 的loss
Cycle GAN训练
Generator采用的是Perceptual losses for real-time style transfer and super-resolution 一文中的网络结构 (论文地址:https://arxiv.org/abs/1603.08155)。
一个resblock组成的网络,下采样部分采用stride 卷积,上采样部分采用反卷积 。
Discriminator采用的仍是pix2pix中的PatchGAN结构,大小为70x70;
定义四个损失函数,分别优化训练G和D,两个生成器共享权重,两个鉴别器也共享权重训练。
Lr=0.0002。对于前100个周期,保持相同的学习速率0.0002,然后在接下来的100个周期内线性衰减到0。
Cycle GAN局限性
对颜色、纹理等的转换效果比较好, 但是:
1.)会在改变物体的同时改变背景。
2.)缺少多样性
Source domain 和 target domain 的维度应该是不一样的,比如笑和不笑,笑自然是闭着嘴,但是不笑的程度多种多样,可能有微笑,哈哈大笑等等。又比如无眼镜就是一种,有眼镜可以是各式各样的眼镜。
3.)在需要几何变化的任务上表现一般。
Cycle GAN的代码实现(pytorch版本)

1 import torch 2 import torch.nn as nn 3 import torch.nn.functional as F 4 from torch.utils import data 5 import torchvision 6 from torchvision import transforms 7 8 import numpy as np 9 import matplotlib.pyplot as plt 10 import os 11 import glob 12 from PIL import Image 13 import itertools 14 15 import time 16 17 dataNames = os.listdir('apple2orange/') 18 print(dataNames) 19 20 def plotImg(data_path, name): 21 22 23 plt.figure(figsize=(12, 8)) 24 for i, img_path in enumerate(data_path[:4]): 25 img = Image.open(img_path) 26 np_img = np.array(img) 27 plt.subplot(2,2,i+1) 28 plt.imshow(np_img) 29 plt.title(str(np_img.shape)) 30 31 plt.show() 32 plt.savefig('apple2orange/result/image_{:s}.png'.format(name)) 33 plt.close() 34 35 36 apples_path = glob.glob('apple2orange/testA/*.jpg') 37 oranges_path = glob.glob('apple2orange/testB/*.jpg') 38 39 # print(len(apples_path), len(oranges_path)) 40 # print(apples_path[:3], oranges_path[:3]) 41 42 # plotImg(apples_path, "apple") 43 # plotImg(oranges_path, "oranges") 44 45 46 transform = transforms.Compose([ 47 transforms.ToTensor(), # 取值范围会被归一化到(0, 1)之间 48 transforms.Normalize(mean=0.5, std=0.5) # 设置均值和方差均为0.5 49 ]) 50 51 class Apple2orange_Dataset(data.Dataset): 52 def __init__(self, imgs_path): 53 self.imgs_path = imgs_path 54 55 def __getitem__(self, index): 56 img_path = self.imgs_path[index] 57 pil_img = Image.open(img_path) 58 pil_img = transform(pil_img) 59 return pil_img 60 61 def __len__(self): 62 return len(self.imgs_path) 63 64 apples_dataset = Apple2orange_Dataset(apples_path) 65 oranges_dataset = Apple2orange_Dataset(oranges_path) 66 print(len(apples_dataset), len(oranges_dataset)) # 打印数据集大小 67 68 69 BTACH_SIZE = 32 # 批次大小 70 apples_dl = torch.utils.data.DataLoader( 71 apples_dataset, 72 batch_size=BTACH_SIZE, 73 shuffle=True 74 ) 75 76 oranges_dl = torch.utils.data.DataLoader( 77 oranges_dataset, 78 batch_size=BTACH_SIZE, 79 shuffle=True 80 ) 81 82 apples_batch = next(iter(apples_dl)) # 返回一个批次的训练数据 83 oranges_batch = next(iter(oranges_dl)) # 返回一个批次的训练数据 84 85 # 绘制批次中前3对图片 86 def pltImgAndMask(imgData, maskData, imgName, maskName): 87 88 # 绘制批次中前3对图片 89 plt.figure(figsize=(12, 18)) 90 91 for i, (img, mask) in enumerate(zip(imgData[:3], maskData[:3])): 92 93 # 设置channel最后,并还原到取值0-1之间 94 img = (img.permute(1, 2, 0).numpy() +1)/2 95 mask = (mask.permute(1, 2, 0).numpy()+1)/2 96 97 plt.subplot(3, 2, 2*i+1) 98 plt.title('apple image') 99 plt.imshow(img) 100 101 plt.subplot(3, 2, 2*i+2) 102 plt.title('orange image') 103 plt.imshow(mask) 104 105 plt.savefig('apple2orange/result/image_{:s}_{:s}.png'.format(imgName, maskName)) 106 plt.close() 107 108 imgName = "apple" 109 maskName = "orrange" 110 # pltImgAndMask(apples_batch, oranges_batch, imgName, maskName) 111 112 113 apples_path_test = glob.glob("apple2orange/testA/*.jpg") 114 oranges_path_test = glob.glob("apple2orange/testB/*.jpg") 115 116 apples_dataset_test = Apple2orange_Dataset(apples_path_test) 117 oranges_dataset_test = Apple2orange_Dataset(oranges_path_test) 118 119 apples_dl_test = torch.utils.data.DataLoader( 120 apples_dataset_test, 121 batch_size=BTACH_SIZE, 122 shuffle=True 123 ) 124 125 oranges_dl_test = torch.utils.data.DataLoader( 126 oranges_dataset_test, 127 batch_size=BTACH_SIZE, 128 shuffle=True 129 ) 130 131 class Downsample(nn.Module): 132 def __init__(self, in_channels, out_channels): 133 super(Downsample, self).__init__() 134 self.conv_relu = nn.Sequential( 135 nn.Conv2d(in_channels, out_channels, 136 kernel_size=3, stride=2, padding=1), 137 nn.LeakyReLU(inplace=True), 138 ) 139 self.bn = nn.InstanceNorm2d(out_channels) 140 def forward(self, x, is_bn=True): 141 x = self.conv_relu(x) 142 if is_bn: 143 x = self.bn(x) 144 return x 145 146 class Upsample(nn.Module): 147 def __init__(self, in_channels, out_channels): 148 super(Upsample, self).__init__() 149 self.upconv_relu = nn.Sequential( 150 nn.ConvTranspose2d(in_channels, out_channels, 151 kernel_size=3, 152 stride=2, 153 padding=1, 154 output_padding=1), 155 nn.LeakyReLU(inplace=True) 156 ) 157 self.bn = nn.InstanceNorm2d(out_channels) 158 159 def forward(self, x, is_drop=False): 160 x = self.upconv_relu(x) 161 x = self.bn(x) 162 if is_drop: 163 x = F.dropout2d(x) 164 return x 165 166 class Generator(nn.Module): 167 def __init__(self): 168 super(Generator, self).__init__() 169 self.down1 = Downsample(3, 64) 170 self.down2 = Downsample(64, 128) 171 self.down3 = Downsample(128, 256) 172 self.down4 = Downsample(256, 512) 173 self.down5 = Downsample(512, 512) 174 self.down6 = Downsample(512, 512) 175 176 self.up1 = Upsample(512, 512) 177 self.up2 = Upsample(1024, 512) 178 self.up3 = Upsample(1024, 256) 179 self.up4 = Upsample(512, 128) 180 self.up5 = Upsample(256, 64) 181 182 self.last = nn.ConvTranspose2d(128, 3, 183 kernel_size=3, 184 stride=2, 185 padding=1, 186 output_padding=1) 187 188 189 def forward(self, x): 190 x1 = self.down1(x, is_bn=False) # torch.Size([8, 64, 128, 128]) 191 x2 = self.down2(x1) # torch.Size([8, 128, 64, 64]) 192 x3 = self.down3(x2) # torch.Size([8, 256, 32, 32]) 193 x4 = self.down4(x3) # torch.Size([8, 512, 16, 16]) 194 x5 = self.down5(x4) # torch.Size([8, 512, 8, 8]) 195 x6 = self.down6(x5) # torch.Size([8, 512, 4, 4]) 196 197 x6 = self.up1(x6, is_drop=True) # torch.Size([8, 512, 8, 8]) 198 x6 = torch.cat([x5, x6], dim=1) # torch.Size([8, 1024, 8, 8]) 199 200 x6 = self.up2(x6, is_drop=True) # torch.Size([8, 512, 16, 16]) 201 x6 = torch.cat([x4, x6], dim=1) # torch.Size([8, 1024, 16, 16]) 202 203 x6 = self.up3(x6, is_drop=True) 204 x6 = torch.cat([x3, x6], dim=1) 205 206 x6 = self.up4(x6) 207 x6 = torch.cat([x2, x6], dim=1) 208 209 x6 = self.up5(x6) 210 x6 = torch.cat([x1, x6], dim=1) 211 212 x6 = torch.tanh(self.last(x6)) 213 return x6 214 215 # 定义判别器 216 class Discriminator(nn.Module): 217 def __init__(self): 218 super(Discriminator, self).__init__() 219 self.down1 = Downsample(3, 64) # 128 220 self.down2 = Downsample(64, 128) # 64 221 self.last = nn.Conv2d(128, 1, 3) 222 223 def forward(self, img): 224 x = self.down1(img) 225 x = self.down2(x) 226 x = torch.sigmoid(self.last(x)) 227 return x 228 229 # 定义判别器 230 class Discriminator(nn.Module): 231 def __init__(self): 232 super(Discriminator, self).__init__() 233 self.down1 = Downsample(3, 64) # 128 234 self.down2 = Downsample(64, 128) # 64 235 self.last = nn.Conv2d(128, 1, 3) 236 237 def forward(self, img): 238 x = self.down1(img) 239 x = self.down2(x) 240 x = torch.sigmoid(self.last(x)) 241 return x 242 243 device = "cuda:1" if torch.cuda.is_available() else "cpu" 244 gen_AB = Generator().to(device) 245 gen_BA = Generator().to(device) 246 dis_A = Discriminator().to(device) 247 dis_B = Discriminator().to(device) 248 249 250 bceloss_fn = torch.nn.BCELoss() # 定义损失函数 251 l1loss_fn = torch.nn.L1Loss() 252 253 gen_optimizer = torch.optim.Adam(itertools.chain(gen_AB.parameters(), gen_BA.parameters()), 254 lr=2e-4, betas=(0.5, 0.999)) 255 dis_A_optimizer = torch.optim.Adam(dis_A.parameters(), lr=2e-4, betas=(0.5, 0.999)) 256 dis_B_optimizer = torch.optim.Adam(dis_B.parameters(), lr=2e-4, betas=(0.5, 0.999)) 257 258 # 绘制测试结果图像 259 def generate_images(model, test_input, epoch): 260 prediction = model(test_input).permute(0, 2, 3, 1).detach().cpu().numpy() 261 test_input = test_input.permute(0, 2, 3, 1).cpu().numpy() 262 263 plt.figure(figsize=(100, 50)) 264 title = ['Input Image', 'Predicted Image'] 265 for i in range(4): 266 plt.subplot(2, 4, i+1) 267 plt.title(title[0]) 268 plt.imshow(test_input[i] * 0.5 + 0.5) 269 plt.axis('off') 270 for i in range(4): 271 plt.subplot(2, 4, i+5) 272 plt.title(title[1]) 273 plt.imshow(prediction[i] * 0.5 + 0.5) 274 plt.axis('off') 275 276 plt.savefig('apple2orange/result/image_at_epoch_{:04d}.png'.format(epoch)) 277 278 279 test_batch = next(iter(apples_dl_test)) 280 # test_input = torch.unsqueeze(test_batch[0], 0).to(device) 281 test_input = test_batch[:4].to(device) 282 283 284 D_loss = [] # 记录训练过程中判别器loss变化 285 G_loss = [] # 记录训练过程中生成器loss变化 286 epochs = [] 287 288 #开始训练 289 for epoch in range(300): 290 291 epoch_start = time.time() 292 293 D_epoch_loss=0 294 G_epoch_loss=0 295 for step, (real_A, real_B) in enumerate(zip(apples_dl, oranges_dl)): 296 real_A = real_A.to(device) 297 real_B = real_B.to(device) 298 299 # GAN 训练 300 gen_optimizer.zero_grad() 301 302 # identity loss 303 same_B = gen_AB(real_B) 304 identity_B_loss = l1loss_fn(same_B, real_B) 305 same_A = gen_BA(real_A) 306 identity_A_loss = l1loss_fn(same_A, real_A) 307 308 # GAN loss 309 fake_B = gen_AB(real_A) 310 D_pred_fake_B = dis_B(fake_B) 311 gan_loss_AB = bceloss_fn(D_pred_fake_B, 312 torch.ones_like(D_pred_fake_B, device=device)) 313 314 fake_A = gen_BA(real_B) 315 D_pred_fake_A = dis_A(fake_A) 316 gan_loss_BA = bceloss_fn(D_pred_fake_A, 317 torch.ones_like(D_pred_fake_A, device=device)) 318 319 # cycle consistanse loss 320 recovered_A = gen_BA(fake_B) 321 cycle_loss_ABA = l1loss_fn(recovered_A, real_A) 322 323 recovered_B = gen_AB(fake_A) 324 cycle_loss_BAB = l1loss_fn(recovered_B, real_B) 325 326 # total_loss 327 g_loss = (identity_B_loss + identity_A_loss + gan_loss_AB + gan_loss_BA 328 + cycle_loss_ABA + cycle_loss_BAB) 329 330 g_loss.backward() 331 gen_optimizer.step() 332 333 # dis_A 训练 334 dis_A_optimizer.zero_grad() 335 dis_A_real_output = dis_A(real_A) # 判别器输入真实图片 336 dis_A_real_loss = bceloss_fn(dis_A_real_output, 337 torch.ones_like(dis_A_real_output, device=device)) 338 339 dis_A_fake_output = dis_A(fake_A.detach()) # 判别器输入生成图片 340 dis_A_fake_loss = bceloss_fn(dis_A_fake_output, 341 torch.zeros_like(dis_A_fake_output, device=device)) 342 343 dis_A_loss = (dis_A_real_loss + dis_A_fake_loss)*0.5 344 345 dis_A_loss.backward() 346 dis_A_optimizer.step() 347 348 349 # dis_B 训练 350 dis_B_optimizer.zero_grad() 351 dis_B_real_output = dis_B(real_B) # 判别器输入真实图片 352 dis_B_real_loss = bceloss_fn(dis_B_real_output, 353 torch.ones_like(dis_B_real_output, device=device)) 354 355 dis_B_fake_output = dis_B(fake_B.detach()) # 判别器输入生成图片 356 dis_B_fake_loss = bceloss_fn(dis_B_fake_output, 357 torch.zeros_like(dis_B_fake_output, device=device)) 358 359 dis_B_loss = (dis_B_real_loss + dis_B_fake_loss)*0.5 360 361 dis_B_loss.backward() 362 dis_B_optimizer.step() 363 364 # 打印 loss 变化 365 with torch.no_grad(): 366 D_epoch_loss += (dis_A_loss + dis_B_loss).item() 367 G_epoch_loss += g_loss.item() 368 369 epoch_finish = time.time() 370 371 with torch.no_grad(): 372 D_epoch_loss /= step 373 G_epoch_loss /= step 374 D_loss.append(D_epoch_loss) 375 G_loss.append(G_epoch_loss) 376 epochs.append(epoch) 377 378 # 训练完一个Epoch,打印提示并绘制生成的图片 379 print("Epoch:", epoch, 380 'D_epoch_loss:{:.2f}'.format(D_epoch_loss), 381 'G_epoch_loss:{:.2f}'.format(G_epoch_loss), 382 'time:{:.2f}s'.format(epoch_finish-epoch_start)) 383 384 generate_images(gen_AB, test_input, epoch) 385 386 387 388 # 绘制loss函数 389 def D_G_loss_plot(D_loss, G_loss, epotchs): 390 391 fig = plt.figure(figsize=(4, 4)) 392 393 plt.plot(epotchs, D_loss, label='D_loss') 394 plt.plot(epotchs, G_loss, label='G_loss') 395 plt.legend() 396 397 plt.title("D_G_Loss") 398 plt.savefig('apple2orange/result/loss_at_epoch_{:04d}.png'.format(epotchs[len(epotchs)-1])) 399 plt.close() 400 401 D_G_loss_plot(D_loss, G_loss, epochs) 402 403 torch.save(gen_AB, 'apple2orange/model/gen_AB_epoch_{:04d}.pt'.format(epochs[len(epochs)-1])) 404 torch.save(gen_BA, 'apple2orange/model/gen_BA_epoch_{:04d}.pt'.format(epochs[len(epochs)-1]))
Cycle GAN的训练效果
Epoch =1
epoch=27

【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· TypeScript + Deepseek 打造卜卦网站:技术与玄学的结合
· Manus的开源复刻OpenManus初探
· 写一个简单的SQL生成工具
· AI 智能体引爆开源社区「GitHub 热点速览」
· C#/.NET/.NET Core技术前沿周刊 | 第 29 期(2025年3.1-3.9)