CVPR 2022 | 重用教师网络分类器实现知识蒸馏

CVPR 2022 | 重用教师网络分类器实现知识蒸馏

  知识蒸馏(KD)致力于将性能好但消耗大的重型网络压缩成轻量化网络。

  为了弥补教师-学生网络的性能差,需要想办法对齐教师网络和学生网络在相同输入下的输出能力。

  近些年的方法基本上都基于教师网络中间层,利用中间层特征给学生网络额外的监督或者精心设计知识表征方法。

  问题在于,无论是高效的知识表征方法还是优化好的超参,都难以保证轻易成功应用于实践。并且经转化的知识多样,导致学生网络的性能提升缺少一个统一且清晰的解释。

  为了解决上述问题,浙江大学和浙江工业大学的研究者提出了一个重用教师网络分类器的简易知识蒸馏技术(SimKD),该技术用单一的L2损失一个特征对齐部分训练学生网络,最后在不需要精心调整超参,不需要平衡不同损失权重的情况下,得到了性能相当的轻量级学生网络。

  上图是CIFAR-100的测试图可视化结果,从100个类别随机采样10个类别。
  从教师网络中提取的特征用较重的颜色表示,从学生网络中提取的特征用浅颜色表示。
  左图是一般的KD技术的结果,右图是本文的SimKD技术的结果。
  可以看出,SimKD技术可以使学生网络具有和教师网络几乎一致的分类能力。

论文链接:
http://arxiv.org/abs/2203.14001
代码:
None

一、本文方法

   一般的知识蒸馏可以看成两个部分的堆栈:由多个非线性层组成的特征编码器、由一个全连接层和softmax激活函数组成的分类器

  这两个部分通过后传播算法进行端到端训练。不同的知识蒸馏方法的不同点主要在于梯度如何表示以及从何处开始计算。

   图a表示原始的知识蒸馏,计算类预测中的梯度,以此更新整个学生网络的参数;图b表示特征蒸馏,相比与原始的蒸馏方法,其通过多样的知识表征,收集更多中间层的梯度信息;图c表示本文的SimKD技术,分类器前面的层中计算L2损失,后传播该梯度只用于更新学生网络的特征编码器和维度投影部分。

1.1 分类器复用

  学生网络直接借用预训练教师网络的分类器部分而没有重新训练一个,这样就不需要标签信息来计算交叉熵损失,使得特征对齐损失成为产生梯度的唯一源。

  这么做的原因是,本文作者认为分类器位于网络的深层,其包含的分类信息举足轻重,但是在原来的知识蒸馏方法中被忽视了,所以尝试直接复用这些层。

1.2 映射器

  想要在学生网络直接复用教师网络分类器,则需要将学生网络中提取的特征维度跟教师网络的对齐。

  对此,本文作者设计了一个映射器,代价是损失了3%以内的剪枝率。对其之后计算L2损失:

二、实验分析
2.1 SimKD技术的整体效果

  表1至3列出了基于十五个网络组合的各种方法的全面性能比较,其中教师和学生模型与类似或完全不同的架构进行了实例化。


2.2 分类器复用操作分析

  为了证实复用分类器操作对于蒸馏效果的关键作用,本文作者基于两种不同的策略来训练学生网络的特征编码器和分类器,展开多组实验。

2.2.1 联合训练

  像特征蒸馏一样,同时训练学生网络的特征编码器和对应的分类器,损失函数设定如下:

  效果如下:

  上图显示了同时训练学生网络的特征编码器和分类器,分别利用这样训练出来的学生网络分类器和教师网络分类器来测试准确率的效果图。

  为了彻底评估联合训练效果,使用了四种不同的教师学生组合以及四组统一的α值。

  可以看出,无论是使用自己的分类器还是复用教师网络分类器,联合训练下学生网络性能都远差于单独的SimKD技术的性能,可见,联合训练的方式很难让学生网络学习到教师网络的重要判别信息。这部分实验证明了特征编码和分类器训练不应该一起进行。

2.2.2 顺序训练

  从头重新训练一个分类器,对比SimKD技术,效果如下:

  从表中可以看出,重新训练一个分类器效果并不稳定,在有些情况下无论怎么调整超参,效果仍然不好。

  这样做要费时费力调整出好的参数,显然不如SimKD技术,省时省力且效果稳定。

  上述和这部分两次实验证明了直接复用教师网络分类器,单独训练特征编码器部分是本文蒸馏技术成功的关键。

2.2.3 复用教师网络的更多深度层

  当然,理论上不仅可以复用分类器,应该是复用越多层性能越靠近教师网络,实验也证明了这一点:

  不过复用层数越多,参数量也越多,实际上,仅仅复用最后的分类器部分已经可以很好地模型的平衡性能和复杂度了。

2.2.4 映射器分析

  上述为了复用分类器,设计了一个用来对其特征维度的映射器,它可以实现学生网络特征精准对齐教师网络特征,便于后面复用分类器,但是同时会带来额外的参数,进而影响模型剪枝率。

  上图展示了本文技术下的不同教师-学生组合中的学生模型的剪枝率和模型性能之间的关系,灰色虚线是对应的原始蒸馏技术的剪枝率。SimKD中由于映射器的原因,为了维持高性能,或多或少需要损失一定的剪枝率。

  上图是所有模型组合实验统计出来的模型剪枝率损失直方图,损失在3%以内,也就是可以控制在可接受范围内,甚至有一些情况下剪枝率会增加。
  调整映射器的设置,最终效果也不同,统计如下:

2.2.5 应用-多教师网络

  图中AVEG相当于原始知识蒸馏的变体,即取多教师网络预测的均值;AEKD采用给多教师网络分配权重策略;AEKD-F是AEKD的变体,不同之处在于合并中间特征;SimKDv 表示映射器改为通过全连接层实现特征对齐的SimKD。

2.2.6 应用-无数据知识蒸馏

  SimKD可以通过将其KD训练步骤替换为“重复分类器”操作和相应的特征对齐来轻松地集成到这些现有方法中,并使其性能提升。

三、总结

  本文提出一个简易知识蒸馏技术,将预训练的教师网络的分类器部分复用到学生网络中,同时设计一个映射器采用L2损失训练,以对齐教师和学生网络的特征维度。

  该技术在实验结果上看,确实达到简单又有效的蒸馏效果。

posted @ 2022-05-21 12:04  万国码aaa  阅读(832)  评论(0编辑  收藏  举报