StarGAN Code Analysis
# Parameter settings
```python
import sys
sys.path.append("/home/hxj/anaconda3/lib/python3.6/site-packages")
from torchvision.datasets import ImageFolder
from PIL import Image
import torch
import os
import random

# Model configuration.
c_dim = 5                 # dimension of domain labels (1st dataset)
c2_dim = 8                # dimension of domain labels (2nd dataset)
celeba_crop_size = 178    # crop size for the CelebA dataset
rafd_crop_size = 256      # crop size for the RaFD dataset
image_size = 128          # image resolution
g_conv_dim = 64           # number of conv filters in the first layer of G
d_conv_dim = 64           # number of conv filters in the first layer of D
g_repeat_num = 6          # number of residual blocks in G
d_repeat_num = 6          # number of strided conv layers in D
lambda_cls = 1            # weight for domain classification loss
lambda_rec = 10           # weight for reconstruction loss
lambda_gp = 10            # weight for gradient penalty

# Training configuration.
dataset = 'CelebA'        # choices=['CelebA', 'RaFD', 'Both']
batch_size = 16           # mini-batch size
num_iters = 200000        # number of total iterations for training D
num_iters_decay = 100000  # number of iterations for decaying lr
g_lr = 0.0001             # learning rate for G
d_lr = 0.0001             # learning rate for D
n_critic = 5              # number of D updates per each G update
beta1 = 0.5               # beta1 for Adam optimizer
beta2 = 0.999             # beta2 for Adam optimizer
resume_iters = None       # resume training from this step
selected_attrs = ['Black_Hair', 'Blond_Hair', 'Brown_Hair', 'Male', 'Young']  # selected attributes for the CelebA dataset

# Test configuration.
test_iters = 200000       # test model from this step

# Miscellaneous.
num_workers = 1
mode = 'test'             # choices=['train', 'test']
use_tensorboard = True

# Directories.
celeba_image_dir = '../data/CelebA_nocrop/images/' if mode == 'train' else '../test/test/'
attr_path = '../data/list_attr_celeba.txt' if mode == 'train' else '../test/test_celeba.txt'
rafd_image_dir = '../data/RaFD/train/'
log_dir = '../test/logs'
model_save_dir = '../stargan/models'
sample_dir = '../test/samples'
result_dir = '../test/result'

# Step size.
log_step = 10
sample_step = 1000
model_save_step = 10000
lr_update_step = 1000
```
```python
import tensorflow as tf  # TensorFlow is imported only for TensorBoard logging.


class Logger(object):
    """Tensorboard logger."""

    def __init__(self, log_dir):
        """Initialize summary writer."""
        self.writer = tf.summary.FileWriter(log_dir)

    def scalar_summary(self, tag, value, step):
        """Add scalar summary."""
        summary = tf.Summary(value=[tf.Summary.Value(tag=tag, simple_value=value)])
        self.writer.add_summary(summary, step)
```
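Note that `tf.summary.FileWriter` and `tf.Summary` are TensorFlow 1.x APIs; under TensorFlow 2.x this logger would need the `tf.compat.v1` compatibility layer. As a quick illustration of how the logger is consumed later in `Solver.train()`, here is a minimal sketch with hypothetical tag/value/step numbers:

```python
# Sketch: write a couple of scalar points that TensorBoard can plot (values are made up).
logger = Logger('../test/logs')
logger.scalar_summary('D/loss_real', 0.42, step=100)  # tag, value, global step
logger.scalar_summary('G/loss_fake', 0.37, step=100)
```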
# Data preprocessing and loading
```python
from torch.utils import data
from torchvision import transforms as T


class CelebA(data.Dataset):
    """Dataset class for the CelebA dataset."""

    def __init__(self, image_dir, attr_path, selected_attrs, transform, mode):
        """Initialize and preprocess the CelebA dataset."""
        self.image_dir = image_dir
        self.attr_path = attr_path
        self.selected_attrs = selected_attrs
        self.transform = transform
        self.mode = mode
        self.train_dataset = []
        self.test_dataset = []
        self.attr2idx = {}
        self.idx2attr = {}
        self.preprocess()

        if mode == 'train':
            self.num_images = len(self.train_dataset)
        else:
            self.num_images = len(self.test_dataset)

    # Each entry of train_dataset has the form:
    #   ['000003.jpg', [True, False, False, False, True]]
    def preprocess(self):
        """Preprocess the CelebA attribute file."""
        lines = [line.rstrip() for line in open(self.attr_path, 'r')]
        all_attr_names = lines[1].split()
        for i, attr_name in enumerate(all_attr_names):
            self.attr2idx[attr_name] = i
            self.idx2attr[i] = attr_name

        lines = lines[2:]
        random.seed(1234)
        random.shuffle(lines)
        for i, line in enumerate(lines):
            split = line.split()
            filename = split[0]
            values = split[1:]

            label = []
            for attr_name in self.selected_attrs:
                idx = self.attr2idx[attr_name]
                label.append(values[idx] == '1')

            if (i + 1) < 2000:
                self.test_dataset.append([filename, label])
            else:
                self.train_dataset.append([filename, label])

        print('Finished preprocessing the CelebA dataset...')

    # This method overrides torch.utils.data.Dataset.__getitem__.
    def __getitem__(self, index):
        """Return one image and its corresponding attribute label."""
        dataset = self.train_dataset if self.mode == 'train' else self.test_dataset
        filename, label = dataset[index]
        image = Image.open(os.path.join(self.image_dir, filename))
        return self.transform(image), torch.FloatTensor(label)

    def __len__(self):
        """Return the number of images."""
        return self.num_images


def get_loader(image_dir, attr_path, selected_attrs, crop_size=178, image_size=128,
               batch_size=16, dataset='CelebA', mode='train', num_workers=1):
    """Build and return a data loader."""
    transform = []
    if mode == 'train':
        transform.append(T.RandomHorizontalFlip())
    transform.append(T.CenterCrop(crop_size))  # to run only once
    transform.append(T.Resize(image_size))
    transform.append(T.ToTensor())
    transform.append(T.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)))
    transform = T.Compose(transform)

    if dataset == 'CelebA':
        # `dataset` becomes an instance of the CelebA class above.
        dataset = CelebA(image_dir, attr_path, selected_attrs, transform, mode)
    elif dataset == 'RaFD':
        # Load a custom dataset via torchvision's ImageFolder (folder.py);
        # the author notes this path raised an error in their setup.
        dataset = ImageFolder(image_dir, transform)

    # The `dataset` argument of DataLoader must be a torch.utils.data.Dataset.
    data_loader = data.DataLoader(dataset=dataset,
                                  batch_size=batch_size,
                                  shuffle=(mode == 'train'),
                                  num_workers=num_workers)
    return data_loader


# celeba_loader is the torch.utils.data.DataLoader returned by get_loader;
# the dataset it wraps is an instance of the CelebA class.
celeba_loader = get_loader(celeba_image_dir, attr_path, selected_attrs, celeba_crop_size,
                           image_size, batch_size, 'CelebA', mode, num_workers)
```
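A minimal sanity check of the loader (assuming the image directory and attribute file configured above actually exist) would look like this; the shapes follow from `image_size=128` and the five selected attributes:

```python
# Sketch: pull one mini-batch and inspect its shapes (assumes the paths above are valid).
images, labels = next(iter(celeba_loader))
print(images.shape)  # torch.Size([batch_size, 3, 128, 128]), pixel values normalized to [-1, 1]
print(labels.shape)  # torch.Size([batch_size, 5]), one float (0./1.) per selected attribute
```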
# Network architecture
```python
import torch.nn as nn
import torch.nn.functional as F
import numpy as np


class ResidualBlock(nn.Module):
    """Residual Block with instance normalization."""

    def __init__(self, dim_in, dim_out):
        super(ResidualBlock, self).__init__()
        self.main = nn.Sequential(
            nn.Conv2d(dim_in, dim_out, kernel_size=3, stride=1, padding=1, bias=False),
            nn.InstanceNorm2d(dim_out, affine=True, track_running_stats=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(dim_out, dim_out, kernel_size=3, stride=1, padding=1, bias=False),
            nn.InstanceNorm2d(dim_out, affine=True, track_running_stats=True))

    def forward(self, x):
        return x + self.main(x)


class Generator(nn.Module):
    """Generator network."""

    def __init__(self, conv_dim=64, c_dim=5, repeat_num=6):
        super(Generator, self).__init__()

        layers = []
        # First conv layer: the input is the image concatenated with the label;
        # 3 is the number of image channels and c_dim is the label dimension.
        layers.append(nn.Conv2d(3 + c_dim, conv_dim, kernel_size=7, stride=1, padding=3, bias=False))
        layers.append(nn.InstanceNorm2d(conv_dim, affine=True, track_running_stats=True))
        layers.append(nn.ReLU(inplace=True))

        # Down-sampling layers.
        curr_dim = conv_dim  # 64 channels at this point
        for i in range(2):
            layers.append(nn.Conv2d(curr_dim, curr_dim*2, kernel_size=4, stride=2, padding=1, bias=False))
            layers.append(nn.InstanceNorm2d(curr_dim*2, affine=True, track_running_stats=True))
            layers.append(nn.ReLU(inplace=True))
            curr_dim = curr_dim * 2
        # After the two iterations, curr_dim is 256.

        # Bottleneck layers.
        for i in range(repeat_num):
            layers.append(ResidualBlock(dim_in=curr_dim, dim_out=curr_dim))

        # Up-sampling layers.
        for i in range(2):
            layers.append(nn.ConvTranspose2d(curr_dim, curr_dim//2, kernel_size=4, stride=2, padding=1, bias=False))
            layers.append(nn.InstanceNorm2d(curr_dim//2, affine=True, track_running_stats=True))
            layers.append(nn.ReLU(inplace=True))
            curr_dim = curr_dim // 2

        # Final conv maps back to 3 channels (an RGB image).
        layers.append(nn.Conv2d(curr_dim, 3, kernel_size=7, stride=1, padding=3, bias=False))
        layers.append(nn.Tanh())
        self.main = nn.Sequential(*layers)

    def forward(self, x, c):
        # Replicate spatially and concatenate domain information.
        c = c.view(c.size(0), c.size(1), 1, 1)    # view is the analogue of NumPy's reshape
        c = c.repeat(1, 1, x.size(2), x.size(3))  # repeat the tensor along the spatial dimensions
        x = torch.cat([x, c], dim=1)              # concatenate the input image x with the label map c
        return self.main(x)


class Discriminator(nn.Module):
    """Discriminator network with PatchGAN."""

    def __init__(self, image_size=128, conv_dim=64, c_dim=5, repeat_num=6):
        super(Discriminator, self).__init__()
        layers = []
        layers.append(nn.Conv2d(3, conv_dim, kernel_size=4, stride=2, padding=1))
        layers.append(nn.LeakyReLU(0.01))

        curr_dim = conv_dim
        for i in range(1, repeat_num):
            layers.append(nn.Conv2d(curr_dim, curr_dim*2, kernel_size=4, stride=2, padding=1))
            layers.append(nn.LeakyReLU(0.01))
            curr_dim = curr_dim * 2

        kernel_size = int(image_size / np.power(2, repeat_num))
        self.main = nn.Sequential(*layers)  # assemble the convolutional trunk
        self.conv1 = nn.Conv2d(curr_dim, 1, kernel_size=3, stride=1, padding=1, bias=False)  # real/fake head
        self.conv2 = nn.Conv2d(curr_dim, c_dim, kernel_size=kernel_size, bias=False)         # domain-label head

    def forward(self, x):
        h = self.main(x)          # x is a batch of input images; after the trunk, h has 2048 channels
        out_src = self.conv1(h)   # out_src: per-patch real/fake scores
        out_cls = self.conv2(h)   # out_cls: domain-label logits
        return out_src, out_cls.view(out_cls.size(0), out_cls.size(1))
```
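To make the tensor shapes concrete, here is a small sanity-check sketch using dummy inputs and the default hyperparameters (the dummy tensors and printed shapes are illustrative, not from the original post):

```python
import torch

# Sketch: verify G and D output shapes with random inputs (defaults as configured above).
G = Generator(conv_dim=64, c_dim=5, repeat_num=6)
D = Discriminator(image_size=128, conv_dim=64, c_dim=5, repeat_num=6)

x = torch.randn(2, 3, 128, 128)      # a dummy batch of two 128x128 RGB images
c = torch.zeros(2, 5)
c[:, 0] = 1                          # a dummy target label (e.g. Black_Hair)

x_fake = G(x, c)
out_src, out_cls = D(x_fake)
print(x_fake.shape)   # torch.Size([2, 3, 128, 128]) -- translated image, same size as the input
print(out_src.shape)  # torch.Size([2, 1, 2, 2])     -- PatchGAN real/fake score map
print(out_cls.shape)  # torch.Size([2, 5])           -- one logit per selected attribute
```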
# Solver
```python
from torchvision.utils import save_image
import time
import datetime


class Solver(object):
    """Solver for training and testing StarGAN."""

    def __init__(self, celeba_loader, rafd_loader):
        """Initialize configurations."""
        # Data loader.
        self.celeba_loader = celeba_loader
        self.rafd_loader = rafd_loader

        # Model configurations.
        self.c_dim = c_dim
        self.c2_dim = c2_dim
        self.image_size = image_size
        self.g_conv_dim = g_conv_dim
        self.d_conv_dim = d_conv_dim
        self.g_repeat_num = g_repeat_num
        self.d_repeat_num = d_repeat_num
        self.lambda_cls = lambda_cls
        self.lambda_rec = lambda_rec
        self.lambda_gp = lambda_gp

        # Training configurations.
        self.dataset = dataset
        self.batch_size = batch_size
        self.num_iters = num_iters
        self.num_iters_decay = num_iters_decay
        self.g_lr = g_lr
        self.d_lr = d_lr
        self.n_critic = n_critic
        self.beta1 = beta1
        self.beta2 = beta2
        self.resume_iters = resume_iters
        self.selected_attrs = selected_attrs

        # Test configurations.
        self.test_iters = test_iters

        # Miscellaneous.
        self.use_tensorboard = use_tensorboard
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # self.device = torch.device('cpu')

        # Directories.
        self.log_dir = log_dir
        self.sample_dir = sample_dir
        self.model_save_dir = model_save_dir
        self.result_dir = result_dir

        # Step size.
        self.log_step = log_step
        self.sample_step = sample_step
        self.model_save_step = model_save_step
        self.lr_update_step = lr_update_step

        # Build the model and tensorboard.
        self.build_model()
        if self.use_tensorboard:
            self.build_tensorboard()

    def build_model(self):
        """Create a generator and a discriminator."""
        if self.dataset in ['CelebA', 'RaFD']:
            self.G = Generator(self.g_conv_dim, self.c_dim, self.g_repeat_num)
            self.D = Discriminator(self.image_size, self.d_conv_dim, self.c_dim, self.d_repeat_num)
        elif self.dataset in ['Both']:
            self.G = Generator(self.g_conv_dim, self.c_dim+self.c2_dim+2, self.g_repeat_num)  # 2 for mask vector.
            self.D = Discriminator(self.image_size, self.d_conv_dim, self.c_dim+self.c2_dim, self.d_repeat_num)

        self.g_optimizer = torch.optim.Adam(self.G.parameters(), self.g_lr, [self.beta1, self.beta2])
        self.d_optimizer = torch.optim.Adam(self.D.parameters(), self.d_lr, [self.beta1, self.beta2])

        # Print the network structures if desired.
        # self.print_network(self.G, 'G')
        # self.print_network(self.D, 'D')

        self.G.to(self.device)
        self.D.to(self.device)

    def print_network(self, model, name):
        """Print out the network information."""
        num_params = 0
        for p in model.parameters():
            num_params += p.numel()
        print(model)
        print(name)
        print("The number of parameters: {}".format(num_params))

    def create_labels(self, c_org, c_dim=5, dataset='CelebA', selected_attrs=None):
        """Generate target domain labels for debugging and testing."""
        # Get hair color indices.
        if dataset == 'CelebA':
            hair_color_indices = []
            for i, attr_name in enumerate(selected_attrs):
                if attr_name in ['Black_Hair', 'Blond_Hair', 'Brown_Hair', 'Gray_Hair']:
                    hair_color_indices.append(i)
            # With the attributes selected above, hair_color_indices == [0, 1, 2].

        c_trg_list = []
        for i in range(c_dim):
            if dataset == 'CelebA':
                c_trg = c_org.clone()
                if i in hair_color_indices:  # Set one hair color to 1 and the rest to 0.
                    c_trg[:, i] = 1
                    for j in hair_color_indices:
                        if j != i:
                            c_trg[:, j] = 0
                else:
                    c_trg[:, i] = (c_trg[:, i] == 0)  # Reverse attribute value.
            elif dataset == 'RaFD':
                c_trg = self.label2onehot(torch.ones(c_org.size(0))*i, c_dim)

            c_trg_list.append(c_trg.to(self.device))
        return c_trg_list

    def denorm(self, x):
        """Convert the range from [-1, 1] to [0, 1]."""
        out = (x + 1) / 2
        return out.clamp_(0, 1)

    def build_tensorboard(self):
        self.logger = Logger(self.log_dir)

    def restore_model(self, resume_iters):
        """Restore the trained generator and discriminator."""
        print('Loading the trained models from step {}...'.format(resume_iters))
        G_path = os.path.join(self.model_save_dir, '{}-G.ckpt'.format(resume_iters))
        D_path = os.path.join(self.model_save_dir, '{}-D.ckpt'.format(resume_iters))
        self.G.load_state_dict(torch.load(G_path, map_location=lambda storage, loc: storage))
        self.D.load_state_dict(torch.load(D_path, map_location=lambda storage, loc: storage))

    def update_lr(self, g_lr, d_lr):
        """Decay learning rates of the generator and discriminator."""
        for param_group in self.g_optimizer.param_groups:
            param_group['lr'] = g_lr
        for param_group in self.d_optimizer.param_groups:
            param_group['lr'] = d_lr

    def reset_grad(self):
        """Reset the gradient buffers."""
        self.g_optimizer.zero_grad()
        self.d_optimizer.zero_grad()

    def classification_loss(self, logit, target, dataset='CelebA'):
        """Compute binary or softmax cross entropy loss."""
        if dataset == 'CelebA':
            return F.binary_cross_entropy_with_logits(logit, target, size_average=False) / logit.size(0)
        elif dataset == 'RaFD':
            return F.cross_entropy(logit, target)

    def gradient_penalty(self, y, x):
        """Compute gradient penalty: (L2_norm(dy/dx) - 1)**2."""
        weight = torch.ones(y.size()).to(self.device)
        dydx = torch.autograd.grad(outputs=y,
                                   inputs=x,
                                   grad_outputs=weight,
                                   retain_graph=True,
                                   create_graph=True,
                                   only_inputs=True)[0]

        dydx = dydx.view(dydx.size(0), -1)
        dydx_l2norm = torch.sqrt(torch.sum(dydx**2, dim=1))
        return torch.mean((dydx_l2norm-1)**2)

    def label2onehot(self, labels, dim):
        """Convert label indices to one-hot vectors."""
        batch_size = labels.size(0)
        out = torch.zeros(batch_size, dim)
        out[np.arange(batch_size), labels.long()] = 1
        return out

    def train(self):
        """Train StarGAN within a single dataset."""
        # Set data loader.
        if self.dataset == 'CelebA':
            data_loader = self.celeba_loader
        elif self.dataset == 'RaFD':
            data_loader = self.rafd_loader

        # Fetch fixed inputs for debugging.
        data_iter = iter(data_loader)
        x_fixed, c_org = next(data_iter)
        # x_fixed holds the image pixels; c_org holds the real labels, e.g. tensor([[1., 0., 0., 1., 1.]]).
        x_fixed = x_fixed.to(self.device)
        c_fixed_list = self.create_labels(c_org, self.c_dim, self.dataset, self.selected_attrs)
        # print(c_fixed_list)
        # [tensor([[1., 0., 0., 1., 1.]]), tensor([[0., 1., 0., 1., 1.]]), tensor([[0., 0., 1., 1., 1.]]),
        #  tensor([[1., 0., 0., 0., 1.]]), tensor([[1., 0., 0., 1., 0.]])]

        # Learning rate cache for decaying.
        g_lr = self.g_lr
        d_lr = self.d_lr

        # Start training from scratch or resume training.
        start_iters = 0
        if self.resume_iters:                 # resume_iters is set to None above
            start_iters = self.resume_iters   # training can resume from a previously saved checkpoint
            self.restore_model(self.resume_iters)

        # Start training.
        print('Start training...')
        start_time = time.time()
        for i in range(start_iters, self.num_iters):

            # =================================================================================== #
            #                             1. Preprocess input data                                #
            # =================================================================================== #

            # Fetch real images and labels.
            try:
                x_real, label_org = next(data_iter)
            except:
                data_iter = iter(data_loader)
                x_real, label_org = next(data_iter)

            # Generate target domain labels randomly.
            rand_idx = torch.randperm(label_org.size(0))  # e.g. tensor([0])
            label_trg = label_org[rand_idx]               # real labels drawn from the data, e.g. tensor([[1., 0., 0., 1., 1.]])

            if self.dataset == 'CelebA':
                c_org = label_org.clone()
                c_trg = label_trg.clone()
            elif self.dataset == 'RaFD':
                c_org = self.label2onehot(label_org, self.c_dim)
                c_trg = self.label2onehot(label_trg, self.c_dim)

            x_real = x_real.to(self.device)        # Input images.
            c_org = c_org.to(self.device)          # Original domain labels, e.g. tensor([[1., 0., 0., 1., 1.]])
            c_trg = c_trg.to(self.device)          # Target domain labels, e.g. tensor([[1., 0., 0., 1., 1.]])
            label_org = label_org.to(self.device)  # Labels for computing classification loss.
            label_trg = label_trg.to(self.device)  # Labels for computing classification loss.

            # =================================================================================== #
            #                             2. Train the discriminator                              #
            # =================================================================================== #

            # Compute loss with real images.
            out_src, out_cls = self.D(x_real)
            # Example values observed by the author:
            #   out_src: tensor(1.00000e-03 * [[[[-1.8202, 0.3373], [-0.5725, 0.4968]]]])
            #   out_cls: tensor(1.00000e-03 * [[ 0.3915, 2.0016, 0.4509, -2.0520, 2.4382]])
            d_loss_real = - torch.mean(out_src)
            # Minimizing d_loss_real pushes out_src towards 1 for real images
            # (observed value: tensor(1.00000e-04 * 3.8965)).
            d_loss_cls = self.classification_loss(out_cls, label_org, self.dataset)
            # Classification loss on the labels (observed value: tensor(3.4666)).

            # Compute loss with fake images.
            # Feed the real image x_real and the fake label c_trg into the generator to get x_fake.
            x_fake = self.G(x_real, c_trg)           # x_fake is a generated image
            out_src, out_cls = self.D(x_fake.detach())
            # Example values observed by the author:
            #   out_src: tensor(1.00000e-03 * [[[[-1.5289, 0.8110], [ 0.2153, 0.4624]]]])
            #   out_cls: tensor(1.00000e-03 * [[ 1.4681, 1.9497, 1.2743, -1.1915, 0.7609]])
            d_loss_fake = torch.mean(out_src)
            # Fake images should be scored 0 (observed value: tensor(1.00000e-05 * -1.0045)).

            # Compute loss for gradient penalty.
            # Sample a random factor alpha, interpolate x_real and x_fake using alpha,
            # feed the result into the discriminator, and penalize the gradient norm.
            alpha = torch.rand(x_real.size(0), 1, 1, 1).to(self.device)  # a random number per sample, e.g. tensor([[[[0.7610]]]])
            x_hat = (alpha * x_real.data + (1 - alpha) * x_fake.data).requires_grad_(True)
            # x_hat is an image-sized tensor that varies with alpha.
            out_src, _ = self.D(x_hat)  # x_hat is the interpolated sample used for the gradient penalty
            d_loss_gp = self.gradient_penalty(out_src, x_hat)
            # The author observed d_loss_gp fluctuating around 0.9954-0.9956.

            # Backward and optimize.
            # The discriminator loss has four terms:
            #   1. real images should be classified as real;
            #   2. images generated by G from a real image plus a wrong label should be classified as fake;
            #   3. the domain labels D predicts for real images should match the true labels;
            #   4. the gradient penalty on the interpolation of real and generated images.
            d_loss = d_loss_real + d_loss_fake + self.lambda_cls * d_loss_cls + self.lambda_gp * d_loss_gp
            self.reset_grad()
            d_loss.backward()
            self.d_optimizer.step()

            # Logging.
            loss = {}
            loss['D/loss_real'] = d_loss_real.item()
            loss['D/loss_fake'] = d_loss_fake.item()
            loss['D/loss_cls'] = d_loss_cls.item()
            loss['D/loss_gp'] = d_loss_gp.item()

            # =================================================================================== #
            #                               3. Train the generator                                #
            # =================================================================================== #

            # The generator translates an image from the original domain to the target domain,
            # and translates the result back to the original domain (reconstruction).
            if (i+1) % self.n_critic == 0:
                # Original-to-target domain.
                # Feed the real image x_real and the target label c_trg into G to get x_fake.
                x_fake = self.G(x_real, c_trg)
                print("c_trg:", c_trg)  # debug print
                out_src, out_cls = self.D(x_fake)
                g_loss_fake = - torch.mean(out_src)  # adversarial loss: the fake image should be scored as real (1)
                g_loss_cls = self.classification_loss(out_cls, label_trg, self.dataset)  # push the prediction towards the target label

                # Target-to-original domain.
                x_reconst = self.G(x_fake, c_org)
                print("c_org:", c_org)  # debug print
                # sys.exit(0)           # debug-only early exit; keep commented out so training continues
                g_loss_rec = torch.mean(torch.abs(x_real - x_reconst))

                # Backward and optimize.
                g_loss = g_loss_fake + self.lambda_rec * g_loss_rec + self.lambda_cls * g_loss_cls
                self.reset_grad()
                g_loss.backward()
                self.g_optimizer.step()

                # Logging.
                loss['G/loss_fake'] = g_loss_fake.item()
                loss['G/loss_rec'] = g_loss_rec.item()
                loss['G/loss_cls'] = g_loss_cls.item()

            # =================================================================================== #
            #                                 4. Miscellaneous                                    #
            # =================================================================================== #

            # Print out training information.
            if (i+1) % self.log_step == 0:
                et = time.time() - start_time
                et = str(datetime.timedelta(seconds=et))[:-7]
                log = "Elapsed [{}], Iteration [{}/{}]".format(et, i+1, self.num_iters)
                for tag, value in loss.items():
                    log += ", {}: {:.4f}".format(tag, value)
                print(log)

                if self.use_tensorboard:
                    for tag, value in loss.items():
                        self.logger.scalar_summary(tag, value, i+1)

            # Translate fixed images for debugging.
            if (i+1) % self.sample_step == 0:
                with torch.no_grad():
                    x_fake_list = [x_fixed]
                    for c_fixed in c_fixed_list:
                        x_fake_list.append(self.G(x_fixed, c_fixed))
                    x_concat = torch.cat(x_fake_list, dim=3)
                    sample_path = os.path.join(self.sample_dir, '{}-images.jpg'.format(i+1))
                    save_image(self.denorm(x_concat.data.cpu()), sample_path, nrow=1, padding=0)
                    print('Saved real and fake images into {}...'.format(sample_path))

            # Save model checkpoints.
            if (i+1) % self.model_save_step == 0:
                G_path = os.path.join(self.model_save_dir, '{}-G.ckpt'.format(i+1))
                D_path = os.path.join(self.model_save_dir, '{}-D.ckpt'.format(i+1))
                torch.save(self.G.state_dict(), G_path)
                torch.save(self.D.state_dict(), D_path)
                print('Saved model checkpoints into {}...'.format(self.model_save_dir))

            # Decay learning rates.
            if (i+1) % self.lr_update_step == 0 and (i+1) > (self.num_iters - self.num_iters_decay):
                g_lr -= (self.g_lr / float(self.num_iters_decay))
                d_lr -= (self.d_lr / float(self.num_iters_decay))
                self.update_lr(g_lr, d_lr)
                print('Decayed learning rates, g_lr: {}, d_lr: {}.'.format(g_lr, d_lr))

    def test(self):
        """Translate images using StarGAN trained on a single dataset."""
        # Load the trained generator.
        self.restore_model(self.test_iters)

        # Set data loader.
        if self.dataset == 'CelebA':
            data_loader = self.celeba_loader
        elif self.dataset == 'RaFD':
            data_loader = self.rafd_loader

        with torch.no_grad():
            for i, (x_real, c_org) in enumerate(data_loader):
                # Prepare input images and target domain labels.
                x_real = x_real.to(self.device)
                c_trg_list = self.create_labels(c_org, self.c_dim, self.dataset, self.selected_attrs)

                # Translate images.
                x_fake_list = [x_real]
                for c_trg in c_trg_list:
                    x_fake_list.append(self.G(x_real, c_trg))

                # Save the translated images.
                x_concat = torch.cat(x_fake_list, dim=3)
                result_path = os.path.join(self.result_dir, '{}-images.jpg'.format(i+1))
                save_image(self.denorm(x_concat.data.cpu()), result_path, nrow=1, padding=0)
                print('Saved real and fake images into {}...'.format(result_path))
```
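To isolate the WGAN-GP term used in the discriminator step above, here is a minimal self-contained sketch of the same interpolation-plus-gradient-norm computation on a toy critic (the `critic` and tensor sizes are hypothetical, not part of the original code):

```python
import torch

# Toy stand-in for D's real/fake head: a single linear layer over flattened "images".
critic = torch.nn.Linear(3 * 8 * 8, 1)

x_real = torch.randn(4, 3, 8, 8)
x_fake = torch.randn(4, 3, 8, 8)

# Interpolate real and fake samples with a per-sample random alpha, as in Solver.train().
alpha = torch.rand(4, 1, 1, 1)
x_hat = (alpha * x_real + (1 - alpha) * x_fake).requires_grad_(True)

out = critic(x_hat.view(4, -1))
grad = torch.autograd.grad(outputs=out, inputs=x_hat,
                           grad_outputs=torch.ones_like(out),
                           retain_graph=True, create_graph=True, only_inputs=True)[0]
grad = grad.view(4, -1)
gp = ((grad.norm(2, dim=1) - 1) ** 2).mean()  # (||d critic(x_hat)/d x_hat||_2 - 1)^2, averaged over the batch
print(gp.item())
```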
# Start training
```python
rafd_loader = None

solver = Solver(celeba_loader, rafd_loader)
solver.train()
```
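To run inference instead of training, the same `Solver` exposes `test()`. A sketch, assuming `mode == 'test'` (as configured above) and that trained checkpoints such as `200000-G.ckpt` and `200000-D.ckpt` already exist in `model_save_dir`:

```python
# Sketch: translate the test images with a previously trained generator.
solver = Solver(celeba_loader, rafd_loader)
solver.test()  # writes '{i}-images.jpg' grids into result_dir
```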