本文分为以下几个部分对该工作进行介绍:
- 文章链接
- 代码链接
- 研究背景
- 提出的方法
- 实验效果
- 投稿历程
论文题目:
Asymmetric Temperature Scaling Makes Larger Networks Teach Well Again
文章来源:
NeurIPS 2022
论文链接:
https://arxiv.org/abs/2210.04427
代码链接:
https://github.com/lxcnju/ATS-LargeKD
https://gitee.com/lxcnju/ats-mindspore
作者主页:
http://www.lamda.nju.edu.cn/lixc/
1.『研究背景』
知识蒸馏(Knowledge Distillation)可以将大(强)模型的能力传递给轻量(弱)模型,其基本形式如下:
其基本步骤为:1)在训练集上训练一个大教师网络,或者拿现有的当作教师网络;2)使用图示的损失去指导学生网络进行训练。损失包括两部分:正常分类的损失和知识蒸馏损失。前者是 hard-label,后者是 soft-label。引入后者的目的是因为学生直接学习 hard-label 太困难了,因此期望学生能够模仿教师的 soft 输出,从而把握类别之间的相似度,从而更好地学习。
值得注意的是:知识蒸馏损失里面的温度系数 Temperature 很重要!如果τ很小,那么教师的输出结果像 hard-label,导致和正常分类损失相比没有什么额外的信息;如果τ很大,那么教师的输出结果像 uniform-label,类别之间的差异性就没有了,仅仅起到了一个 label smoothing 的作用。
普遍的认知是越好的教师教学生教地越好。然而实际上,2019 年有学者 [Jang Hyun Cho, 2019] 指出:大神经网络不一定教地好!
引用下面的一个示意图(来自 [Seyed Iman Mirzadeh, 2020]),随着 teacher size 逐渐变大,教师的准确率越来越高(红色的 teacher accuracy),但是其教的学生的准确率先变高再变低(蓝色的 student accuracy)。现有的工作都将这个奇怪的现象归因于”大教师网络“和”小学生网络“之间的容量差异(capacity gap),但是没有形象地指出这种差异为何出现。
因此,本文的研究内容就是:为什么大神经网络不一定教地好,有没有什么简单的办法让大神经网络教地好?
2.『提出的方法』
本文最直接的猜测起源于下面的式子:
也就是说,在遍历所有可能温度系数的情况下,相比较于大教师网络,小教师网络更容易给出质量更好的指导信息,即
首先,文章通过一些观察实验发现:大教师网络更容易给出置信度较高的预测,包括两个方面。其一,正确类别的 logit 可能更大;其二,错误类别 logits 之间差异更小。本文称神经网络最后一层给出的类别预测得分称为 logits。
具体地,在 CIFAR-100 和 CIFAR-10 上训练 ResNet14/44/110 和 WRN28-1/4/8,统计神经网络输出的 logits 的如下指标:
这就是最基本的现象,也是整个工作的启发点:大神经网络更为置信,给出的 target logit 更大,或者 wrong logits 差异更小!
那么我们不妨设想两个极端:
因此本文的猜测为:大神经网络不能教地好的原因是无论使用怎样的温度系数,都难以使得错误类别概率“错落有致”。
为了从理论上去推导验证,本文将知识蒸馏分为三个部分:
分别包括:1)Correct Guidance,类似于 hard-label 的 one-hot 标签;2)Smooth Regularization,错误类别的平均概率值,类似于 label smoothing;3)Class Discriminability,错误类别之间的差异,可以用方差来度量,错误类别差异越大,教师提供的指导信息越多!
接下来是理论分析,先定义一些符号和公式:
理论分析:
事实上:随着τ不断增大,得到的p的熵越来越大,即越来越均匀。本研究证明了:随着τ不断增大,得到的p元素之间的方差也会越来越小。
在正确类别 logit 最大情况下:随着τ不断增大,错误类别概率的均值
逐渐增大。
最重要的等式为:
其中 DA、IV、DV 分别解释如下:
- Inherent Variance:错误类别 logits 经过 softmax 之后得到的类别概率分布的方差;
- Derived Average:所有类别 logits 经过 softmax 之后得到的错误类别概率的平均值;
- Derived Variance:所有类别 logits 经过 softmax 之后得到的错误类别概率的方差。
针对某一个样本的计算示意图如下(SF 代表 Softmax):
利用该公式解释为什么大教师网络教不好:
翻译为中文为:
- 大教师网络给的正确类别 logit 的值很大,导致 DA 变小;
- 大教室网络给的错误类别 logit 的差异很小,导致 IV 变小。
最终都会导致 DV 变小,即:大教师网络的 DV 很小,传统温度缩放下很难让错误类别的概率“错落有致”。
提出的方法为 Asymmetric Temperature Scaling(ATS),针对正确/错误类别施加较大/较小的温度系数:
结论:ATS 可以使得大教师网络的 DV 变大让错误类别的概率“错落有致”。
3.『实验结果』
实验设置和结果就不详细介绍了,有兴趣的可以看文章。下面就简单贴一下结果:
4.『投稿历程』
到此,本文的基本方法都介绍完了,是一个非常简单的改进。研究设计的过程中也充满了乐趣,主要包括三个过程:
- 发现大神级网络和小神经网络输出的结果具有一些差异,兴奋值 ++;
- 发现可以将知识蒸馏的损失分解为三部分,特别是 class discriminability 的定义很有意思,兴奋值 ++;
- 发现可以用公式解释大教师神经网络的 DV 很小,兴奋值 ++;
- 发现可以提出一个非常简单的 ATS 来使得大教师教地更好,兴奋值 ++。
该工作完成于 2021.1 月份左右,在新年前几天完成的,满怀期待投稿了 ICML 2022。很不幸的是被拒了,个人感觉是在边缘,因为审稿人给的意见都没有特别严重的,主要是一些行文思路和概念没有解释清楚。于是完善了之后转投了 NeurIPS,得分为 2 (Strong Reject),5(Borderline Accept),6 (Weak Accept)。看到审稿意见本想放弃,但仔细一看给 2 分的貌似只是针对我们公式符号的不合理性进行了攻击,感觉还是有希望的。于是修改了符号,提交了 rebuttal revision。审稿人然后就将分数改为 6。最终得分为 666。