WGAN Explained

Reference: https://blog.csdn.net/omnispace/article/details/54942668

The blog post above explains WGAN very well!

PS:

(1) The weight clipping used in the original WGAN was later replaced by the gradient penalty (WGAN-GP); a sketch of the old clipping step is shown below for contrast.

Reference: http://www.sohu.com/a/138121777_494939
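
For contrast, the original WGAN enforces the Lipschitz constraint by simply clamping every critic parameter into a fixed range after each update. A minimal sketch (the clip threshold c = 0.01 follows the WGAN paper; model_D is the critic used in the code below):

# WGAN weight clipping: after every critic update,
# clamp each parameter of the critic into [-c, c]
c = 0.01  # clip threshold from the WGAN paper
for p in model_D.parameters():
    p.data.clamp_(-c, c)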

Code:

import torch
from torch.autograd import grad

# Gradient penalty, computed via autograd
LAMBDA_GRAD_PENALTY = 1.0  # the WGAN-GP paper uses 10
alpha = torch.rand(BATCH_SIZE, 1, 1, 1).cuda()
# pred_penalty holds generated (fake) samples, D_gt_v holds real samples
differences = pred_penalty - D_gt_v
# random points on the straight lines between real and fake samples
interpolates = (D_gt_v + alpha * differences).requires_grad_(True)
D_interpolates = model_D(interpolates)
# create_graph=True so that the penalty itself can be backpropagated
gradients = grad(outputs=D_interpolates, inputs=interpolates,
                 grad_outputs=torch.ones_like(D_interpolates),
                 create_graph=True, retain_graph=True, only_inputs=True)[0]
# penalize the deviation of the per-sample gradient L2 norm from 1
gradients_norm = torch.sqrt(torch.sum(gradients ** 2, dim=(1, 2, 3)) + 1e-12)
gradient_penalty = torch.mean((gradients_norm - 1) ** 2) * LAMBDA_GRAD_PENALTY
loss_D += gradient_penalty
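
For context, here is a hedged sketch of where this penalty sits in one critic update; the Wasserstein loss terms and optimizer_D are assumptions, not part of the original snippet:

# one critic update in WGAN-GP (optimizer_D is assumed to exist)
optimizer_D.zero_grad()
# Wasserstein critic loss: mean(D(fake)) - mean(D(real))
loss_D = model_D(pred_penalty.detach()).mean() - model_D(D_gt_v).mean()
# add the gradient penalty computed as above
loss_D = loss_D + gradient_penalty
loss_D.backward()
optimizer_D.step()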

 
