Distilling the Knowledge in a Neural Network

Hinton G., Vinyals O. & Dean J. Distilling the Knowledge in a Neural Network. arXiv preprint arXiv 1503.02531

q1=exp(zi/T)jexp(zj/T).

主要内容

这篇文章或许重点是在迁移学习上, 一个重点就是其认为soft labels (即概率向量)比hard target (one-hot向量)含有更多的信息. 比如, 数字模型判别数字237的概率分别是0.1, 0.01, 这说明这个数字2很有可能和3长的比较像, 这是one-hot无法带来的信息.

于是乎, 现在的情况是:

  1. 以及有一个训练好的且往往效果比较好但是计量大的模型t;

  2. 我们打算用一个小的模型s去近似这个已有的模型;

  3. 策略是每个样本x, 先根据t(x)获得soft logits zRK, 其中K是类别数, 且z未经softmax.

  4. 最后我们希望根据下面的损失函数来训练s:

    L(x,y)=T2Lsoft(x,y)+λLhard(x,y)

其中

Lsoft(x,y)=i=1Kpi(x)logqi(x)=i=1Kexp(vi(x)/T)jexp(vj(x)/T)logexp(zi(x)/T)jexp(zj(x)/T)

Lhard(x,y)=logexp(zy(x))jexp(zj(x))

至于T2是怎么来的, 这是为了配平梯度的magnitude.

Lsoftzk=i=1Kpiqiqizk=1Tpki=1Kpiqi(1Tqiqk)=1T(pki=1Kpiqk)=1T(qkpk)=1T(ezi/Tjezj/Tevi/Tjevj/T).

T足够大的时候, 并假设jzj=0=jvj=0, 有

Lsoftzk1KT2(zkvk).

故需要加个T2取抵消这部分的影响.

代码

其实一直很好奇的一点是这部分代码在pytorch里是怎么实现的, 毕竟pytorch里的交叉熵是

logpy(x)

另外很恶心的一点是, 我看大家都用的是 KLDivLOSS, 但是其实现居然是:

L(x,y)=ylogyyx,

注: 这里的是逐项的.

def kl_div(x, y):
    return y * (torch.log(y) - x)


x = torch.randn(2, 3)
y = torch.randn(2, 3).abs() + 1

loss1 = F.kl_div(x, y, reduction="none")
loss2 = kl_div(x, y)

这时, 出来的结果长这样

tensor([[-1.5965,  2.2040, -0.8753],
        [ 3.9795,  0.0910,  1.0761]])
tensor([[-1.5965,  2.2040, -0.8753],
        [ 3.9795,  0.0910,  1.0761]])

又或者:

def kl_div(x, y):
    return (y * (torch.log(y) - x)).sum(dim=1).mean()


torch.manual_seed(10086)

x = torch.randn(2, 3)
y = torch.randn(2, 3).abs() + 1

loss1 = F.kl_div(x, y, reduction="batchmean")
loss2 = kl_div(x, y)

print(loss1)
print(loss2)
tensor(2.4394)
tensor(2.4394)

所以如果真要弄, 应该要

def soft_loss(z, v, T=10.):
    # z: logits
    # v: targets
    z = F.log_softmax(z / T, dim=1)
    v = F.softmax(v / T, dim=1)
    return F.kl_div(z, v, reduction="batchmean")
posted @   馒头and花卷  阅读(228)  评论(0编辑  收藏  举报
编辑推荐:
· 开发者必知的日志记录最佳实践
· SQL Server 2025 AI相关能力初探
· Linux系列:如何用 C#调用 C方法造成内存泄露
· AI与.NET技术实操系列(二):开始使用ML.NET
· 记一次.NET内存居高不下排查解决与启示
阅读排行:
· Manus重磅发布:全球首款通用AI代理技术深度解析与实战指南
· 被坑几百块钱后,我竟然真的恢复了删除的微信聊天记录!
· 没有Manus邀请码?试试免邀请码的MGX或者开源的OpenManus吧
· 园子的第一款AI主题卫衣上架——"HELLO! HOW CAN I ASSIST YOU TODAY
· 【自荐】一款简洁、开源的在线白板工具 Drawnix
点击右上角即可分享
微信分享提示