NeurIPS 2022 | 知识蒸馏中如何让“大教师网络”也教得好?

2022-12-29 12:38 886 阅读 ID:642
机器学习算法与自然语言处理
机器学习算法与自然语言处理

本文分为以下几个部分对该工作进行介绍:

  • 文章链接
  • 代码链接
  • 研究背景
  • 提出的方法
  • 实验效果
  • 投稿历程

论文题目:

Asymmetric Temperature Scaling Makes Larger Networks Teach Well Again

文章来源:

NeurIPS 2022

论文链接:

https://arxiv.org/abs/2210.04427

代码链接:

https://github.com/lxcnju/ATS-LargeKD

https://gitee.com/lxcnju/ats-mindspore

作者主页:

http://www.lamda.nju.edu.cn/lixc/

1.『研究背景』

知识蒸馏(Knowledge Distillation)可以将大(强)模型的能力传递给轻量(弱)模型,其基本形式如下:

                                                                    ▲ 知识蒸馏示意图

其基本步骤为:1)在训练集上训练一个大教师网络,或者拿现有的当作教师网络;2)使用图示的损失去指导学生网络进行训练。损失包括两部分:正常分类的损失和知识蒸馏损失。前者是 hard-label,后者是 soft-label。引入后者的目的是因为学生直接学习 hard-label 太困难了,因此期望学生能够模仿教师的 soft 输出,从而把握类别之间的相似度,从而更好地学习。

值得注意的是:知识蒸馏损失里面的温度系数 Temperature 很重要!如果τ很小,那么教师的输出结果像 hard-label,导致和正常分类损失相比没有什么额外的信息;如果τ很大,那么教师的输出结果像 uniform-label,类别之间的差异性就没有了,仅仅起到了一个 label smoothing 的作用。

                                                              ▲ 知识蒸馏中温度系数的作用

普遍的认知是越好的教师教学生教地越好。然而实际上,2019 年有学者 [Jang Hyun Cho, 2019] 指出:大神经网络不一定教地好!

引用下面的一个示意图(来自 [Seyed Iman Mirzadeh, 2020]),随着 teacher size 逐渐变大,教师的准确率越来越高(红色的 teacher accuracy),但是其教的学生的准确率先变高再变低(蓝色的 student accuracy)。现有的工作都将这个奇怪的现象归因于”大教师网络“和”小学生网络“之间的容量差异(capacity gap),但是没有形象地指出这种差异为何出现。

                                                               ▲ 大神经网络不一定教得好

因此,本文的研究内容就是:为什么大神经网络不一定教地好,有没有什么简单的办法让大神经网络教地好

                                                                       ▲ 文章的研究内容

2.『提出的方法』

本文最直接的猜测起源于下面的式子:

                                  ▲ 大教师网络和小教师网络在教同一个学生网络的区别

也就是说,在遍历所有可能温度系数的情况下,相比较于大教师网络,小教师网络更容易给出质量更好的指导信息,即

首先,文章通过一些观察实验发现:大教师网络更容易给出置信度较高的预测,包括两个方面。其一,正确类别的 logit 可能更大;其二,错误类别 logits 之间差异更小。本文称神经网络最后一层给出的类别预测得分称为 logits。

                                         ▲ 大神经网络和小神经网络的输出 logits 的分布

具体地,在 CIFAR-100 和 CIFAR-10 上训练 ResNet14/44/110 和 WRN28-1/4/8,统计神经网络输出的 logits 的如下指标:

                                                       ▲ 大教师和小教师给出的logits示例

这就是最基本的现象,也是整个工作的启发点:大神经网络更为置信,给出的 target logit 更大,或者 wrong logits 差异更小!

那么我们不妨设想两个极端:

                                                                 ▲ 根据现象得到的猜测

因此本文的猜测为:大神经网络不能教地好的原因是无论使用怎样的温度系数,都难以使得错误类别概率“错落有致”。

为了从理论上去推导验证,本文将知识蒸馏分为三个部分:

                                                                      ▲ 知识蒸馏分解

分别包括:1)Correct Guidance,类似于 hard-label 的 one-hot 标签;2)Smooth Regularization,错误类别的平均概率值,类似于 label smoothing;3)Class Discriminability,错误类别之间的差异,可以用方差来度量,错误类别差异越大,教师提供的指导信息越多!

接下来是理论分析,先定义一些符号和公式:

                                                                    ▲ 一些基本的符号

  理论分析:

事实上:随着τ不断增大,得到的p的熵越来越大,即越来越均匀。本研究证明了:随着τ不断增大,得到的p元素之间的方差也会越来越小。

在正确类别 logit 最大情况下:随着τ不断增大,错误类别概率的均值

逐渐增大。

最重要的等式为:

其中 DA、IV、DV 分别解释如下:

  • Inherent Variance:错误类别 logits 经过 softmax 之后得到的类别概率分布的方差;
  • Derived Average:所有类别 logits 经过 softmax 之后得到的错误类别概率的平均值;
  • Derived Variance:所有类别 logits 经过 softmax 之后得到的错误类别概率的方差。

针对某一个样本的计算示意图如下(SF 代表 Softmax):

                                                          ▲ DA、IV、DV关系示意图

利用该公式解释为什么大教师网络教不好:

翻译为中文为:

  • 大教师网络给的正确类别 logit 的值很大,导致 DA 变小;
  • 大教室网络给的错误类别 logit 的差异很小,导致 IV 变小。

最终都会导致 DV 变小,即:大教师网络的 DV 很小,传统温度缩放下很难让错误类别的概率“错落有致”

提出的方法为 Asymmetric Temperature Scaling(ATS),针对正确/错误类别施加较大/较小的温度系数

结论:ATS 可以使得大教师网络的 DV 变大让错误类别的概率“错落有致”

3.『实验结果』

实验设置和结果就不详细介绍了,有兴趣的可以看文章。下面就简单贴一下结果:

4.『投稿历程』

到此,本文的基本方法都介绍完了,是一个非常简单的改进。研究设计的过程中也充满了乐趣,主要包括三个过程:

  1. 发现大神级网络和小神经网络输出的结果具有一些差异,兴奋值 ++;
  2. 发现可以将知识蒸馏的损失分解为三部分,特别是 class discriminability 的定义很有意思,兴奋值 ++;
  3. 发现可以用公式解释大教师神经网络的 DV 很小,兴奋值 ++;
  4. 发现可以提出一个非常简单的 ATS 来使得大教师教地更好,兴奋值 ++。

该工作完成于 2021.1 月份左右,在新年前几天完成的,满怀期待投稿了 ICML 2022。很不幸的是被拒了,个人感觉是在边缘,因为审稿人给的意见都没有特别严重的,主要是一些行文思路和概念没有解释清楚。于是完善了之后转投了 NeurIPS,得分为 2 (Strong Reject),5(Borderline Accept),6 (Weak Accept)。看到审稿意见本想放弃,但仔细一看给 2 分的貌似只是针对我们公式符号的不合理性进行了攻击,感觉还是有希望的。于是修改了符号,提交了 rebuttal revision。审稿人然后就将分数改为 6。最终得分为 666。

免责声明:作者保留权利,不代表本站立场。如想了解更多和作者有关的信息可以查看页面右侧作者信息卡片。
反馈
to-top--btn