Bai T., Chen J., Zhao J., Wen B., Jiang X., Kot A. Feature Distillation With Guided Adversarial Contrastive Learning. arXiv preprint arXiv 2009.09922, 2020.
概
本文是通过固定教师网络(具有鲁棒性), 让学生网络去学习教师网络的鲁棒特征. 相较于一般的distillation 方法, 本文新加了reweight机制, 另外其损失函数非一般的交叉熵, 而是最近流行的对比损失.
主要内容
本文的思想是利用robust的教师网络f t f t 来辅助训练学生网络f s f s , 假设有输入( x , y ) ( x , y ) , 通过网络得到特征
t + := f t ( x ) , s + := f s ( x ) , t + := f t ( x ) , s + := f s ( x ) ,
则( t + , s + ) ( t + , s + ) 构成正样本对, 自然我们需要学生网络提取的特征s + s + 能够逼近t + t + , 进一步, 构建负样本对, 采样样本{ x − 1 , x − 2 , … , x − k } { x 1 − , x 2 − , … , x k − } , 同时得到负样本对( t + , s − i ) ( t + , s i − ) , 其中s − i = f s ( x − i ) s i − = f s ( x i − ) . 总的样本对就是
S p a i r := { ( t + , s + ) , ( t + , s − 1 ) , … , ( t + , s − k ) } . S p a i r := { ( t + , s + ) , ( t + , s 1 − ) , … , ( t + , s k − ) } .
根据负样本采样的损失, 最大化
J ( θ ) := E ( t , s ) ∼ p ( t , s ) log P ( 1 | t , s ; θ ) + E ( t , s ) ∼ q ( t , s ) log P ( 0 | t , s ; θ ) . J ( θ ) := E ( t , s ) ∼ p ( t , s ) log P ( 1 | t , s ; θ ) + E ( t , s ) ∼ q ( t , s ) log P ( 0 | t , s ; θ ) .
当然对于本文的问题需要特殊化, 既然先验P ( C = 1 ) = 1 k + 1 , P ( C = 0 ) = k k + 1 P ( C = 1 ) = 1 k + 1 , P ( C = 0 ) = k k + 1 , 故
J ( θ ) := E ( t , s ) ∼ p ( t , s ) log P ( 1 | t , s ; θ ) + k ⋅ E ( t , s ) ∼ q ( t , s ) log P ( 0 | t , s ; θ ) . J ( θ ) := E ( t , s ) ∼ p ( t , s ) log P ( 1 | t , s ; θ ) + k ⋅ E ( t , s ) ∼ q ( t , s ) log P ( 0 | t , s ; θ ) .
q ( t , s ) q ( t , s ) 是一个区别于p ( t , s ) p ( t , s ) 的分布, 本文采用了p ( t ) q ( s ) p ( t ) q ( s ) .
作者进一步对前一项加了解释
P ( 1 | t , s ; θ ) = P ( t , s ) P ( C = 1 ) P ( t , s ) P ( C = 1 ) + P ( t ) P ( s ) P ( C = 0 ) ≤ P ( t , s ) k ⋅ P ( t ) P ( s ) , P ( 1 | t , s ; θ ) = P ( t , s ) P ( C = 1 ) P ( t , s ) P ( C = 1 ) + P ( t ) P ( s ) P ( C = 0 ) ≤ P ( t , s ) k ⋅ P ( t ) P ( s ) ,
故
E ( t , s ) ∼ p ( t , s ) log P ( 1 | t , s ; θ ) + log k ≤ I ( t , s ) . E ( t , s ) ∼ p ( t , s ) log P ( 1 | t , s ; θ ) + log k ≤ I ( t , s ) .
又J ( θ ) J ( θ ) 的第二项是负的, 故
J ( θ ) ≤ I ( t , s ) , J ( θ ) ≤ I ( t , s ) ,
所以最大化J ( θ ) J ( θ ) 能够一定程度上最大化t , s t , s 的互信息.
reweight
教师网络一般要求精度(干净数据集上的准确率)比较高, 但是通过对抗训练所生成的教师网络往往并不具有这一特点, 所以作者采取的做法是, 对特征t t 根据其置信度来加权w w , 最后损失为
L ( θ ) := E ( t , s ) ∼ p ( t , s ) w t log P ( 1 | t , s ; θ ) + k ⋅ E ( t , s ) ∼ p ( t ) p ( s ) w t log P ( 0 | t , s ; θ ) , L ( θ ) := E ( t , s ) ∼ p ( t , s ) w t log P ( 1 | t , s ; θ ) + k ⋅ E ( t , s ) ∼ p ( t ) p ( s ) w t log P ( 0 | t , s ; θ ) ,
其中
w t ← p y p r e d = y ( f t , t + ) ∈ [ 0 , 1 ] . w t ← p y p r e d = y ( f t , t + ) ∈ [ 0 , 1 ] .
即w t w t 为教师网络判断t + t + 类别为y y (真实类别)的概率.
拟合概率P ( 1 | t , s ; θ ) P ( 1 | t , s ; θ )
在负采样中, 这类概率是直接用逻辑斯蒂回归做的, 本文采用
P ( 1 | t , s ; θ ) = h ( t , s ) = e t T s / τ e t T s / τ + k M , P ( 1 | t , s ; θ ) = h ( t , s ) = e t T s / τ e t T s / τ + k M ,
其中M M 为数据集的样本个数.
会不会
e t T s / τ e t T s / τ + γ ⋅ k M 2 , e t T s / τ e t T s / τ + γ ⋅ k M 2 ,
把γ γ 也作为一个参数训练符合NCE呢?
实验的细节
文中有如此一段话
we sample negatives from different classes rather than different instances, when picking up a positive sample from the same class.
也就是说在实际实验中, t + , s + t + , s + 对应的类别是同一类的, t + , s − t + , s − 对应的类别不是同一类的.
In our view, adversarial examples are like hard examples supporting the decision boundaries. Without hard examples, the distilled models would certainly make mistakes. Thus, we adopt a self-supervised way to generate adversarial examples using Projected Gradient Descent (PGD).
也就是说, t , s t , s 都是对抗样本?
超参数: k = 16384 k = 16384 , τ = 0.1 τ = 0.1 .
疑问
算法中的采样都是针对单个样本的, 但是我想实际训练的时候应该还是batch的, 不然太慢了, 但是如果是batch的话, 怎么采样呢?
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· 开发者必知的日志记录最佳实践
· SQL Server 2025 AI相关能力初探
· Linux系列:如何用 C#调用 C方法造成内存泄露
· AI与.NET技术实操系列(二):开始使用ML.NET
· 记一次.NET内存居高不下排查解决与启示
· Manus重磅发布:全球首款通用AI代理技术深度解析与实战指南
· 被坑几百块钱后,我竟然真的恢复了删除的微信聊天记录!
· 没有Manus邀请码?试试免邀请码的MGX或者开源的OpenManus吧
· 园子的第一款AI主题卫衣上架——"HELLO! HOW CAN I ASSIST YOU TODAY
· 【自荐】一款简洁、开源的在线白板工具 Drawnix