三分钟理解知识蒸馏
知识蒸馏的意义
能够压缩模型,提升模型性能
为什么能够压缩模型?
!!!谁知道了告诉我一下!!!
为什么能提升模型精度?
栗子:分类问题有三个分类:猫,狗,乌龟,实际训练过程中,比如当前的数据真实标签是:猫,模型预测出猫,狗,乌龟的概率分别是0.6, 0.3, 0.1,
传统思路:不错,识别对了,猫的概率最高,给模型一定的奖励;
知识蒸馏:不错,识别对了,猫的概率最高,并且狗比乌龟更像猫,给模型一定的奖励;
总结:即便是负样本,也包含大量知识,知识蒸馏能把这部分知识也学习起来。
大致步骤:
1. 基于一个已经训练好的NET-T模型,该模型经过大量数据的训练准确度很高,但是模型笨重,将NET-T模型最终softmax结果进行软化,生成soft-target,继而生成loss1;
2. 创造一个轻量模型NET-S正常前像传播,实际标签用one-hot向量表示即hard-target,生成loss2;
3. 将loss1与loss2加权求和生成loss3;
4. loss3用于更新NET-S网络;
即将NET-T模型的知识迁移到NET-S上并优化性能