[论文理解] Virtual Adversarial Training: a Regularization Method for Supervised and Semi-supervised Learning
Virtual Adversarial Training: a Regularization Method for Supervised and Semi-supervised Learning
简介
本文是17年半监督学习的一篇文章,受对抗训练的启发,将对抗训练的范式用于提升半监督学习,并且取得了非常好的效果。不同于最近一直比较火的对比学习,这些稍微“传统”一点的方法我觉得还是有一定研究价值的,对比学习利用的增广还是太多利用了人类的先验,并不普适。
Intution
我们知道,大多数场景下,神经网络的输入都是连续的,那么如果我们能让神经网络平滑(对x的领域内的输入有相似的输出),那么就可以保证相似的输入通过同一神经网络得到相似的输出,基于这样的想法,那么自然就可以给没有标签的样本一个与之输入相近的伪标签,这一算法称之为label propagation。然而这样做并不总是work,因为最近有大量的工作表明,神经网络很容易收到输入微小变动的攻击,即输入微小变动一点,输出天差万别。对抗样本的生成会使得网络遭到攻击,从而上面让网络平滑的想法就无法实现。于是本文就想,能否利用生成对抗样本的方法,使得输入微小改变,但是仍然让改变的样本和改变之前的样本的到相似的输出?(这样的方法应该被广泛用于对抗攻击的训练当中,但是并没有人将其用于半监督学习)
当前的半监督算法,一类是通过增广,保证增广前后样本具有相似的输出,这种可以理解为让模型“平滑”;另一类是通过生成对抗网络,生成样本填充流形的低密度区域,此类方法并不需要模型“平滑”。但是后者往往缺乏合理的解释,本文主要着手于前者研究。
Method
本文的VAT(Virtual Adversarial Training)方法最初的定义为这样的Loss:
这里\(x_*\)是有标签或无标签的样本。
上面用于优化两分布的损失函数可以使用KL散度,而最重要的是如何计算\(r_{qadv}\)
我们先将\(D[q(y|x_*), p(y|x_*+r,\theta)]\)简写为\(D(r, x_*, \theta)\),假定\(p(y|x_*,\theta)\)关于\(\theta\)和\(x\)是二阶处处可微的,我们知道,当r=0的时候,\(D(r, x_*, \theta)\)必定取得最小值,所以有\(\nabla_rD(r, x_*, \theta) |_{r=0} = 0\).
根据泰勒展开有:
而我们前面对r有有一个约束\(||r||_2 \leq \epsilon\)。根据瑞利熵原理,上式取得最大值时,\(r\)应为最大特征值对应的特征向量:
其中上划线代表将任意一非零向量投影到其对应方向的单位向量。
然后问题就变成了求海森矩阵的特征向量的问题了。
一般的,我们可能在numpy中会直接调用接口来求特征向量,但是获得海森矩阵的计算还是挺大的,如果能根据一阶的梯度来计算,运算就会小很多,但本文采用了幂迭代法来求解。
算法也就两步:
Input: matrix H
Output: V main eigenvector of H
initialize V randomly;
repeat {
V <- HV
V <- V / ||V||
} until convergence;
令每次迭代
其中初始的d为随机向量,最终d将收敛到特征向量u。
因而可以从一个随机向量d出发,先对海森矩阵H做近似:
因此每一步迭代就变成了
这玩意变成梯度了,因此可以通过pytorch等自动求导到工具来实现了。
最终的Loss就是有监督的loss和上述adv loss加权。
Coding
import torch
import torch.nn as nn
import torch.nn.functional as F
def criterion(pred_p, pred_q):
p = F.softmax(pred_p, dim = 1)
q = F.softmax(pred_q, dim = 1)
return F.kl_div(p, q)
def vat_loss(model, x, iters, ep = 0.1):
model.eval()
pred = model(x)
# 1. 初始化随机向量
d = torch.rand(x.shape)
d = F.normalize(d)
# 2. 幂迭代
for i in range(iters):
r = ep * d
r.requires_grad = True
d_ = criterion(pred, model(x + r))
d_.backward()
d = F.normalize(r.grad)
model.zero_grad()
model.train()
r_adv = ep * d
loss_adv = criterion(pred, model(x + r_adv))
return loss_adv
class SimpleNet(nn.Module):
def __init__(self):
super(SimpleNet, self).__init__()
def forward(self, x):
return x**2
if __name__ == "__main__":
x = torch.randn(2,3)
net = SimpleNet()
loss = vat_loss(net, x, 3)
print(loss)
结果
在cifar10上复现的结果大致是: