论文链接:
https://arxiv.org/abs/2211.16231
开源代码:
https://github.com/zhengli97/CTKD
一、 背景问题
目前已有的蒸馏方法中,都会采用带有温度超参的KL Divergence Loss进行计算,从而在教师模型和学生模型之间进行蒸馏,公式如下:
那么这就带来了两个问题:
- 不同的教师学生模型在KD过程中最优超参不一定是4。如果要找到这个最佳超参,需要进行暴力搜索,会带来大量的计算,整个过程非常低效。
- 一直保持静态固定的温度超参对学生模型来说不是最优的。基于课程学习的思想,人类在学习过程中都是由简单到困难的学习知识。那么在蒸馏的过程中,我们也会希望模型一开始蒸馏是让学生容易学习的,然后难度再增加。难度是一直动态变化的。
于是一个自然而然的想法就冒了出来:
在蒸馏任务里,能不能让网络自己学习一个适合的动态温度超参进行蒸馏,并且参考课程学习,形成一个蒸馏难度由易到难的情况?
于是我们就提出了CTKD来实现这个想法。
二、方法
CTKD的论文的结构图如下:
CTKD方法可以简单分为左右两个部分:
两种方案的对比图如Fig.2所示。
2. 难度逐渐增加的课程学习部分。
三、实验结果
三个数据集:CIFAR-100,ImageNet和MS-COCO。
CIFAR-100上,CTKD的实验结果:
作为一个即插即用的插件,应用在已有的SOTA方法上:
在ImageNet上的实验:
在MS-COCO的detection实验上:
温度超参的整体学习过程可视化:
由以上图可以看到,CTKD整体的动态学习τ的过程。
将CTKD应用在多种现有的蒸馏方案上,可以取得广泛的提升效果。
四、总结
本文提出了一种基于动态温度超参的蒸馏新框架CTKD,在学生模型学习的过程中,可学习的温度超参被训练去以对抗的方式最大化蒸馏损失。通过可学习温度超参,CTKD将蒸馏组织成了一个由易到难的任务,取得了明显的提升。同时该方法可以作为即插即用的插件,应用在已有的SOTA方法上带来广泛的提升效果。
作者:李政
来源:知乎【https://zhuanlan.zhihu.com/p/595735843】