Note:[ wechat:Y466551 | 可加勿骚扰,付费咨询 ]
论文信息
论文标题:Virtual Adversarial Training: A Regularization Method for Supervised and Semi-Supervised Learning
论文作者:Takeru Miyato, S. Maeda, Masanori Koyama, S. Ishii
论文来源:2020 ECCV
论文地址:download
论文代码:download
视屏讲解:click
1 前言
摘要
我们提出了一种基于虚拟对抗损失的新正则化方法:给定输入的条件标签分布的局部平滑度的新度量。 虚拟对抗损失被定义为每个输入数据点周围的条件标签分布对局部扰动的鲁棒性。 与对抗训练不同,我们的方法定义了没有标签信息的对抗方向,因此适用于半监督学习。 因为我们平滑模型的方向只是“虚拟”对抗,所以我们称我们的方法为虚拟对抗训练(VAT)。 增值税的计算成本相对较低。 对于神经网络,虚拟对抗损失的近似梯度可以通过不超过两对前向和反向传播来计算。 在我们的实验中,我们将 VAT 应用于多个基准数据集上的监督和半监督学习任务。 通过基于熵最小化原理对算法进行简单增强,我们的 VAT 在 SVHN 和 CIFAR-10 上实现了半监督学习任务的最先进性能。
正文
VAT:一种普适性的,可以用来代替传统 正则化 和 对抗训练 的神经网络模型训练鲁棒性能提升手段,具有快捷、有效、参数少的优点,并天然契合半监督学习。
正则化项很大程度上是用来平滑模型边界,传统的通过对输入添加噪声干扰提升鲁棒性能其实也是一种平滑的手段,已有研究表明,对输入数据的局部、随机的扰动在半监督任务里面显得非常有效,但于此同时,也有其他研究证明了这种方法存在其很强的局限性—他们发现传统的随机扰动的方法会让模型对某些特定方向的微小扰动表现得很敏感、脆弱,这些方向被称作 对抗方向。
2 介绍
为了解决模型对对抗方向上的噪声的欠鲁棒性,本文提出了一种对抗训练方法:
公式的目的是想在对抗方向上找一个模小于 的扰动使得 损失最大,一般情况下很难找到一个这样精确的 ,所以本文采用线性估计的方法找到最近似的扰动:
当范数为 时,对抗性扰动可以近似为
3 方法
3.1 虚拟对抗训练
对抗训练是一种成功的方法,适用于许多监督问题。 但是,并非始终提供完整的标签信息。 让 代表 或 。虚拟对抗训练目标函数如下:
在本研究中,使用当前估计值 代替 。 通过这种折中,得到了 的再现:
损失 可以被认为是当前模型在每个输入数据点 处的局部平滑度的负度量,它的减少将使模型在每个数据点处变得平滑。在本研究中提出的正则化项是所有输入数据点的 的平均值:
完整的目标函数:
其中, 是标记数据集的负对数似然。 VAT 是一种使用正则化器 的训练方法。
VAT 的一个显着优点是只有两个标量值超参数:
-
- 对抗方向的范数约束 ;
- 控制负对数似然之间的相对平衡的正则化系数 和正则化器 ;
实验:
上图直观的显示了 VAT 在半监督任务上的表现的举例,可以看到第二行第二列,在一开始模型迭代伦次较少的情况下,有大量的无标签数据(那些大量的灰色点)会有较高的 LDS(深蓝色),这是因为一开始的模型对相同类别的数据点预测了不同的标签(见同列第一行),VAT 会给予这些 LDS 较高数据点更大的压力,来迫使模型让数据点间的边界平滑。
代码:

class VATLoss(nn.Module):
def __init__(self, xi=10.0, eps=1.0, ip=1):
"""VAT loss
:param xi: hyperparameter of VAT (default: 10.0)
:param eps: hyperparameter of VAT (default: 1.0)
:param ip: iteration times of computing adv noise (default: 1)
"""
super(VATLoss, self).__init__()
self.xi = xi #10.0
self.eps = eps #1.0
self.ip = ip #1
def forward(self, model, x):
with torch.no_grad():
pred = F.softmax(model(x), dim=1) #torch.Size([32, 10])
# prepare random unit tensor
d = torch.rand(x.shape).sub(0.5).to(x.device) #torch.Size([32, 3, 32, 32])
d = _l2_normalize(d)
with _disable_tracking_bn_stats(model):
# calc adversarial direction
for _ in range(self.ip):
d.requires_grad_()
pred_hat = model(x + self.xi * d)
logp_hat = F.log_softmax(pred_hat, dim=1)
adv_distance = F.kl_div(logp_hat, pred, reduction='batchmean')
adv_distance.backward()
d = _l2_normalize(d.grad)
model.zero_grad()
# calc LDS
r_adv = d * self.eps #r_adv.requires_grad = False
pred_hat = model(x + r_adv) # x + r_adv .requires_grad = False
logp_hat = F.log_softmax(pred_hat, dim=1)
lds = F.kl_div(logp_hat, pred, reduction='batchmean')
return lds

def train(args, model, device, data_iterators, optimizer):
model.train()
for i in tqdm(range(args.iters)):
if i % args.log_interval == 0:
ce_losses = utils.AverageMeter()
vat_losses = utils.AverageMeter()
prec1 = utils.AverageMeter()
x_l, y_l = next(data_iterators['labeled'])
x_ul, _ = next(data_iterators['unlabeled'])
x_l, y_l = x_l.to(device), y_l.to(device)
x_ul = x_ul.to(device)
optimizer.zero_grad()
vat_loss = VATLoss(xi=args.xi, eps=args.eps, ip=args.ip) # 10.0 1.0 1
cross_entropy = nn.CrossEntropyLoss()
lds = vat_loss(model, x_ul)
output = model(x_l)
classification_loss = cross_entropy(output, y_l)
loss = classification_loss + args.alpha * lds
loss.backward()
optimizer.step()
因上求缘,果上努力~~~~ 作者:别关注我了,私信我吧,转载请注明原文链接:https://www.cnblogs.com/BlairGrowing/p/17343354.html
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· 阿里最新开源QwQ-32B,效果媲美deepseek-r1满血版,部署成本又又又降低了!
· 单线程的Redis速度为什么快?
· SQL Server 2025 AI相关能力初探
· AI编程工具终极对决:字节Trae VS Cursor,谁才是开发者新宠?
· 展开说说关于C#中ORM框架的用法!