域泛化问题 (Domain Generalization Problem) 一直是被广泛关注的领域,人们希望通过研究 DG 问题从而让模型像人类一样,很好地识别常见分布之外的视觉数据。当前的 DG 方法,主要集中在设计特定的损失函数或者梯度优化方式,如域对齐 (Invariant Learning)、元学习 (Meta Learning) 等,这些方法有很好的理论保障及充分的分析,但遗憾在于其效果并不足够显著。
在 ICLR 2023 上,来自南洋理工大学-商汤科技联合研究中心、香港科技大学、加拿大蒙特利尔大学和 NVIDIA 的研究者提出从神经网络结构的角度探索不同网络结构的泛化能力,即通用专家混合 (GMoE) 模型,用于进一步提升 Vision Transformer 在 Domain generalization 上的性能。实验结果表明 GMoE 相对于目前最先进的方法,取得了优异的 DG 性能结果。
论文链接:
https://openreview.net/forum?id=RecZ9nB9Q4
代码链接:
https://github.com/Luodian/Generalizable-Mixture-of-Experts
视频解读:
https://www.bilibili.com/video/BV1jV4y1C7h8
一、动机与背景
域泛化是指在一个新的、未知的领域或环境中,一个模型可以在没有特定的领域或环境知识的情况下进行有效的预测。
域泛化在现实生活中具有很高的应用价值。例如,在医疗诊断中,由于医疗数据难以获取,模型必须在不同的医院、城市或国家之间进行泛化,以便有效地进行诊断;在自动驾驶汽车领域,模型必须能够在各种天气、路况和道路类型等不同环境下泛化,以实现可靠的自动驾驶。
因此,域泛化是一个重要的研究方向,可以使机器学习模型在更广泛的实际应用场景中实现更好的性能。
目前,域泛化方法主要可以分为以下几类:
- 基于数据增强的方法:该方法通过对训练数据进行不同的增强操作,如旋转、平移、缩放等,增加训练数据的多样性,以提高模型的泛化能力
- 基于特征对齐的方法:该方法通过对源域和目标域的特征进行对齐,减小不同域之间的分布差异,以提高模型的泛化能力
- 基于元学习的方法:该方法通过在训练过程中学习如何快速适应新的领域,以提高模型的泛化能力
- 基于集成学习的方法:该方法通过组合多个不同的模型或训练过程,提高模型的泛化能力
以上所提到的方法各有优劣,其中主要的缺陷包括:
- 基于数据增强的方法可能会导致过拟合,因为增强操作可能会使模型过度关注一些特定的特征
- 基于特征对齐的方法需要对源域和目标域的数据进行对齐,但在实际应用中,源域和目标域之间的分布差异可能非常大,导致对齐效果不佳
- 基于元学习的方法需要大量的元训练数据,并且可能会导致过拟合,因为元学习的目标是在训练过程中快速适应新的领域,而不是在整个训练集上获得最佳性能
- 基于集成学习的方法需要组合多个模型或训练过程,这可能会导致计算成本较高,并且可能需要更多的训练数据来训练多个模型
在认识到现在的预适应方法所存在的问题后,我们认为有必要从一个新的角度去思考如何更好地解决这个问题。
最近出现的 Vision Transformers 在视觉领域的各个任务中逐渐代替了 CNN,成为被广泛采用的网络结构。因此,我们认为网络结构和泛化性之间可能存在着密不可分的联系。
在机器学习中,归纳偏置是指在模型选择和学习算法中使用的先验知识和假设,它们可以帮助模型从数据中学习有用的模式,而不仅是记住特定的训练实例。一个好的归纳偏置可以帮助模型更快地收敛,更准确地泛化到新数据,以及更好地抵抗过拟合。
不同的网络结构可以提供不同的归纳偏置、不同的能力来表示数据的特征。例如,卷积神经网络(CNN)在图像领域的应用中表现出色,是因为CNN结构天生适合处理图像中的局部性和平移不变性。类似地,循环神经网络(RNN)适合处理序列数据,因为它们具有自然的时间归纳偏置。
目前已经有一些相关工作提出了理论工具[1,2],用于分析神经网络结构在解决不同问题时的能力强弱。然而,目前这些分析仍然存在于 In-distribution learning problem 中,而我们的问题则更关注于 Out-of-distribution learning problem。因此,我们对[1]中提出的 algorithmic alignment 在 DG 问题上进行了延伸分析。
二 、方法介绍
承接上述分析,我们推测一个好的网络结构可能更容易在数据中学习到更适用于域泛化的特征。接下来我们借助 Algorithmic Alignment 工具,从这个推测出发,在理论上一步步进行分析。
首先我们简单介绍 Algorithmic Alignment,它通过衡量神经网络结构与目标函数之间的相似性表征独立同分布(IID)推理任务的易处理程度(Easiness)。
Algorithmic Alignment被正式定义为以下内容。
接下来,我们在 DG 中定义了一些关键概念。目标函数是训练集和测试集之间的不变关系。为了简单起见,我们假设标签是无噪声的。
借助以上的定义,我们可以将算法对齐从独立同分布泛化(IID generalization)扩展到域泛化(DG)问题上。
Theorem 1 表明,与不变关系对齐的网络更能够抵抗分布的变化。我们可以用实验检验,不同类型的网络的泛化能力强弱。
我们在 DomainBed 上首先测试使用 ERM 训练的 ViT 的性能,结果如图1(a)所示。令人惊讶的是,在使用了更少参数的情况下,使用 ERM 训练的 ViT 在几个数据集上已经优于使用 SOTA DG 算法的 ResNet-50。这表明在 DG 中,选择骨干网络结构可能比损失函数更为重要。
我们可以发现,如果神经网络结构与不变关系(invariant correlation)对齐,ERM足以实现良好的性能。在 OfficeHome 或 DomainNet 的某些领域中,形状属性与标签之间存在不变关系,如图1(b)所示。
相反,属性纹理和标签之间存在虚假相关性(spurious correlation)。根据[3]的分析,多头注意力(MHA)是具有形状偏置的低通滤波器,而卷积是具有纹理偏置的高通滤波器。因此,仅使用 ERM 训练的 ViT 就可以胜过使用 SOTA DG 算法训练的 CNN。
进一步地,我们也很好奇如何提高 ViT 的泛化能力?Theorem 1 建议我们应该利用不变关系的特性。
在图像识别中,一个物体通常由不同部分组成(例如,我们可以用视觉属性来组合性的描述一个物体[4])。在真实世界的图像数据中,标签依赖于多个属性。对于 DG 而言,捕捉多样的视觉属性特别重要。例如,牛津词典中对大象的定义是“一种拥有厚厚的灰色皮肤、大耳朵、两个称为象牙的弯曲外齿和一个称为象鼻的长鼻子的大型动物”。
那么,应该如何捕捉这些视觉特征呢?这些视觉特征又是如何决定一个物体的类别的呢?
条件语句(即编程中的 IF/ELSE),如算法1所示,在 DG 问题里,可以被试做根据视觉属性的组合,在不同域中判断一个物体的类别的工具。
假设我们在 DomainNet 上训练网络以识别大象,如图1(b)的第一行所示。对于不同领域的大象,形状和纹理差异显著,而视觉属性(大耳朵、弯曲的牙齿、长鼻子)在所有领域中都是不变的。借助条件语句,对大象的识别可以表述为“如果一只动物有大耳朵、两个弯曲的外齿和一个长鼻子,那么它就是一只大象”。然后子任务是识别这些视觉属性,这也需要条件语句。
通过 Theorem 2,我们证明了一个基于 ViT 结构的多 Experts 的 Mixture-of-Experts 网络结构,可以很好地在 Algorithmic Alignment 框架下对齐 IF-ELSE 语句。通过执行 IF-ELSE 语句,能够很好地捕捉到一个物体的不同区域的特征(如大象的大耳朵、弯曲的牙齿、长鼻子)。我们也基于前人在 MoE 方向的探索[5,6],提出了我们的 Generalizable Mixture-of-Experts (GMoE)。其结构如下:
三 、实验结果
我们在 Table 1 中提供了 train-validation selection 的结果,其中包括 baselines、最新的 SOTA DG 方法以及使用 ERM 训练的 GMoE。
结果表明,GMoE-S/16 即使在没有 DG 算法的情况下,已经在几乎所有数据集上表现优于以前基于 ResNet-50-S/16 的 DG 方法。
GMoE 的泛化能力来自于其内部骨干网络结构,这与现有的 DG 算法是正交的。这意味着 SOTA DG 算法可以应用于改进 GMoE 的性能。
为了验证这个想法,我们应用了两个 SOTA DG 算法改进 GMoE,其中一个是修改损失函数的方法(FISH),另一个是采用模型集成的方法(SWAD)。Table 2 的结果表明,采用 GMoE,相比于 ResNet-50,显著提高了这些已有 DG 方法的性能。
我们同样在限定了基础模型结构的 IID 性能(ViT-S/16和 ResNet-50 V2) 基础上,比较这两个模型的 DG 性能。以下是对比结果,可以看到 ViT-S/16 在略输 ResNet-50 V2 的情况下,仍然在 DG 任务上取得了更好的性能。
以下是 GMoE 的 Expert Selection 可视化结果。图像来自于 CUB-DG 中自然领域的不同类别。图中不同颜色的线连接不同图像上的同一类别鸟类的视觉属性(Visual Attributes)。同一视觉属性由同一 Expert 处理,例如嘴和尾巴由 Expert 3 处理,左/右腿由 Expert 4 处理。
参考文献
[1] Xu, Keyulu, et al. “What can neural networks reason about?.” ICLR 2020 (Spotlight)
[2] Xu, Keyulu, et al. “How neural networks extrapolate: From feedforward to graph neural networks.” ICLR 2021 (Oral)
[3] Namuk Park and Songkuk Kim. How do vision transformers work? ICLR 2022 (Spotlight)
[4] Object detectors emerge in deep scene cnns. ICLR 2015
[5] Riquelme, Carlos, et al. “Scaling vision with sparse mixture of experts.” NeurIPS 2021
[6] Chi, Zewen, et al. “On the representation collapse of sparse mixture of experts.” NeurIPS 2022
公众号:商汤学术