更通用、有效,蚂蚁自研优化器WSAM入选KDD Oral

2023-10-13 12:36 402 阅读 ID:1514
机器之心
机器之心

深度神经网络(DNNs)的泛化能力与极值点的平坦程度密切相关,因此出现了 Sharpness-Aware Minimization (SAM) 算法来寻找更平坦的极值点以提高泛化能力。本文重新审视 SAM 的损失函数,提出了一种更通用、有效的方法 WSAM,通过将平坦程度作为正则化项来改善训练极值点的平坦度。通过在各种公开数据集上的实验表明,与原始优化器、SAM 及其变体相比,WSAM 在绝大多数情形都实现了更好的泛化性能。WSAM 在蚂蚁内部数字支付、数字金融等多个场景也被普遍采用并取得了显著效果。该文被 KDD '23 接收为 Oral Paper。

  • 论文地址:https://arxiv.org/pdf/2305.15817.pdf
  • 代码地址:https://github.com/intelligent-machine-learning/dlrover/tree/master/atorch/atorch/optimizers

随着深度学习技术的发展,高度过参数化的 DNNs 在 CV 和 NLP 等各种机器学习场景下取得了巨大的成功。虽然过度参数化的模型容易过拟合训练数据,但它们通常具有良好的泛化能力。泛化的奥秘受到越来越多的关注,已成为深度学习领域的热门研究课题。

最近的研究表明,泛化能力与极值点的平坦程度密切相关,即损失函数“地貌”中平坦的极值点可以实现更小的泛化误差。Sharpness-Aware Minimization (SAM) [1] 是一种用于寻找更平坦极值点的技术,是当前最有前途的技术方向之一。它广泛应用于各个领域,如 CV、NLP 和 bi-level learning,并在这些领域明显优于原先最先进的方法。

为了探索更平坦的最小值,SAM 定义损失函数 L 在 w 处的平坦程度如下:

  • 我们提出 WSAM,将平坦度视为正则化项,并在不同任务之间给予不同的权重。我们提出一个“权重解耦”技术来处理更新公式中的正则化项,旨在精确反映当前步骤的平坦度。当基础优化器不是 SGD 时,如 SGDM 和 Adam,WSAM 在形式上与 SAM 有显著差异。消融实验表明,这种技术在大多数情况下可以提升效果。
  • 我们在公开数据集上验证了 WSAM 在常见任务中的有效性。实验结果表明,与 SAM 及其变体相比,WSAM 在绝大多数情形都有着更好的泛化性能。

预备知识

方法介绍

WSAM 的设计细节

简单示例

实验

我们在各种任务上进行了实验,以验证 WSAM 的有效性

图像分类

我们首先研究了 WSAM 在 Cifar10 和 Cifar100 数据集上从零开始训练模型的效果。我们选择的模型包括 ResNet18 和WideResNet-28-10。我们使用预定义的批大小在 Cifar10 和 Cifar100 上训练模型,ResNet18 和 WideResNet-28-10 分别为 128,256。这里使用的基础优化器是动量为 0.9 的 SGDM。按照 SAM [1] 的设置,每个基础优化器跑的 epoch 数是 SAM 类优化器的两倍。我们对两种模型都进行了 400 个 epoch 的训练(SAM 类优化器为 200 个 epoch),并使用 cosine scheduler 来衰减学习率。这里我们没有使用其他高级数据增强方法,例如 cutout 和 AutoAugment。

Tab. 2 给出了在不同优化器下,ResNet18、WRN-28-10 在 Cifar10 和 Cifar100 上测试集的 top-1 错误率。相比基础优化器,SAM 类优化器显著提升了效果,同时,WSAM 又显著优于其他 SAM 类优化器。

ImageNet 上的额外训练

我们进一步在 ImageNet 数据集上使用 Data-Efficient Image Transformers 网络结构进行实验。我们恢复了一个预训练的 DeiT-base checkpoint,然后继续训练三个 epoch。模型使用批大小 256 进行训练,基础优化器为动量 0.9 的 SGDM,权重衰减系数为 1e-4,学习率为 1e-5。我们在四卡 NVIDIA A100 GPU 重复跑 5 次并计算平均误差和标准差。

标签噪声的鲁棒性

如先前的研究 [1, 4, 5] 所示,SAM 类优化器在训练集存在标签噪声时表现出良好的鲁棒性。在这里,我们将 WSAM 的鲁棒性与 SAM、ESAM 和 GSAM 进行了比较。我们在 Cifar10 数据集上训练 ResNet18 200 个 epoch,并注入对称标签噪声,噪声水平为 20%、40%、60% 和 80%。我们使用具有 0.9 动量的 SGDM 作为基础优化器,批大小为 128,学习率为 0.05,权重衰减系数为 1e-3,并使用 cosine scheduler 衰减学习率。针对每个标签噪声水平,我们在 {0.01, 0.02, 0.05, 0.1, 0.2, 0.5} 范围内对 SAM 进行网格搜索,确定通用的

值。然后,我们单独搜索其他优化器特定的超参数,以找到最优泛化性能。我们在 Tab. 5 中列出了复现我们结果所需的超参数。我们在 Tab. 6 中给出了鲁棒性测试的结果,WSAM 通常比 SAM、ESAM 和 GSAM 都具有更好的鲁棒性。

探索几何结构的影响

SAM 类优化器可以与 ASAM [4] 和 Fisher SAM [5] 等技术相结合,以自适应地调整探索邻域的形状。我们在 Cifar10 上对 WRN-28-10 进行实验,比较 SAM 和 WSAM 在分别使用自适应和 Fisher 信息方法时的表现,以了解探索区域的几何结构如何影响 SAM 类优化器的泛化性能。

消融实验

在本节中,我们进行消融实验,以深入理解 WSAM 中“权重解耦”技术的重要性。如WSAM 的设计细节所述,我们将不带“权重解耦”的 WSAM 变体(算法 4)Coupled-WSAM 与原始方法进行比较。


结果如 Tab. 8 所示。Coupled-WSAM 在大多数情况下比 SAM 产生更好的结果,WSAM 在大多数情况下进一步提升了效果,证明“权重解耦”技术的有效性。  

极值点分析

在这里,我们通过比较 WSAM 和 SAM 优化器找到的极值点之间的差异,进一步加深对 WSAM 优化器的理解。极值点处的平坦(陡峭)度可通过 Hessian 矩阵的最大特征值来描述。特征值越大,越不平坦。我们使用 Power Iteration 算法来计算这个最大特征值。


Tab. 9 显示了 SAM 和 WSAM 优化器找到的极值点之间的差异。我们发现,vanilla 优化器找到的极值点具有更小的损失值但更不平坦,而 SAM 找到的极值点具有更大的损失值但更平坦,从而改善了泛化性能。有趣的是,WSAM 找到的极值点不仅损失值比 SAM 小得多,而且平坦度十分接近 SAM。这表明,在寻找极值点的过程中,WSAM 优先确保更小的损失值,同时尽量搜寻到更平坦的区域。  

超参敏感性

视频介绍

                                                                                2分钟版本
                                                                             13分钟版本

参考文献

[1] Pierre Foret et al. Sharpness-aware Minimization for Efficiently Improving Generalization. ICLR '21.

[2] Juntang Zhuang et al. Surrogate Gap Minimization Improves Sharpness-Aware Training. ICLR '22.

[3] Jiawei Du et al. Efficient Sharpness-aware Minimization for Improved Training of Neural Networks. ICLR '22.

[4] Jungmin Kwon et al. ASAM: Adaptive Sharpness-Aware Minimization for Scale-Invariant Learning of Deep Neural Networks. ICML '21.

[5] Minyoung Kim et al. Fisher SAM: Information Geometry and Sharpness Aware Minimisation. ICML '22.

免责声明:作者保留权利,不代表本站立场。如想了解更多和作者有关的信息可以查看页面右侧作者信息卡片。
反馈
to-top--btn