机器学习模型通常在训练数据中对于少数群体表现较差。然而,我们对导致子群体偏移(subpopulation shift)的机制变化以及算法在如此多样化的转变中的泛化情况了解甚少。在这项工作中,我们对子群体偏移进行了细致的分析。我们首先提出了一个统一的框架,用于解释子群体中常见的偏移情况。然后,我们在视觉、语言和医疗领域的 12 个真实世界数据集上建立了一个全面的基准,评估了 20 种最先进的算法。
通过训练超过 10,000 个模型并获得结果,我们揭示了这一领域未来发展的有趣观察结果。目前代码,数据,和模型已经在 GitHub 上开源。
论文名称:
Change is Hard: A Closer Look at Subpopulation Shift
论文链接:
https://arxiv.org/abs/2302.12254
项目主页:
https://subpopbench.csail.mit.edu/
代码链接:
https://github.com/YyzHarry/SubpopBench
Talk链接:
https://www.youtube.com/watch?v=WiSrCWAAUNI
一、研究背景与动机
机器学习模型在训练数据存在不均衡的情况下,通常会在少数群体(subgroup)上表现不佳。例如,在识别牛和骆驼的任务 [1](cow-camel problem)中,我们知道牛通常出现在有绿草的地区,而骆驼则经常出现在有黄沙背景的地方。然而,这种相关性是虚假的,因为牛或骆驼的存在与背景颜色无关。因此,经过训练的模型在上述图像上表现良好,但无法推广到在训练数据中罕见且具有不同背景颜色的动物,例如沙滩上的牛或草地上的骆驼。
此外,当涉及到医学诊断时,研究发现机器学习模型在年龄或种族上代表性不足的群体上表现通常较差,引发了重要的公平性问题。所有这些问题通常被泛化的称为子群体偏移问题,但我们对导致子群体偏移的机制变化以及算法在如此多样化的偏移中的泛化情况了解甚少。在这项工作中,我们对子群体偏移进行了细致的分析。
二、子群体偏移的统一框架
为了对 subpopulation shift 进行建模,我们提出了一个统一的框架。在经典的分类问题的设置中,我们有来自多个类别的训练数据,并且每个类别中的样本数量可能不同。然而,在涉及到子群体存在偏移的情况下,除了类别之外,还存在属性(attribute)上的偏移,例如牛和骆驼分类问题中的背景颜色。在这种情况下,我们可以根据属性和标签的组合定义离散的子群体,而在同一类别中,不同属性的样本数量也可能不同。
而为了测试模型的性能,我们需要在所有子群体上进行测试,以确保在所有子群体中最差的性能足够好,或者确保在所有群体中性能相同且良好。其中最为广泛应用的指标是 worst group accuracy(WGA)[2],顾名思义,即是在 group=(label x attribute)里的最差准确率要足够好。
为了提供一个通用的数学形式,受到最近泛化的长尾分类问题的启发 [3],我们使用贝叶斯定理重写分类模型,并将其分解为三个项。这种建模方法解释了在子群体偏移下属性和类别是如何影响结果的。
最后,某些属性在训练中可能完全缺失,但在测试中对于某些类别是存在的,这促使我们需要考虑属性泛化,Attribute generalization。这四种情况构成了最基本的子群体偏移的组成部分,它们是解释实际数据中复杂子群体偏移的重要因素。每种偏移所产生的原因,以及其对分类模型的影响,由下表归纳所示。
需要注意的是,这四种偏移仅仅是最基础的字群体偏移类型;而在实际问题中,数据集通常同时包含多种类型的偏移,而不仅仅是一种。我们在文章中提出了一些能够量化数据集中每种偏移程度的指标。
三、SubpopBench: 子群体偏移的 Benchmark
那么现在,在建立了子群体偏移的建模和细化分类之后,我们提出了 SubpopBench,一个包含 12 个真实世界数据集上评估最先进的 20 多种算法的综合基准 benchmark。具体而言,这些数据集来自各种模态和任务,包括视觉 [4][5][6][7][8][9]、语言 [10][11] 和医疗 [12][13] 的应用,而数据模态涵盖自然图像、文本、临床文本和胸部 X 射线等多种形式。不同数据集还展示了不同的子群体偏移的成分。
这里具体细节就不多赘述了,详细内容还请参考我们的文章。那么通过建立这一基准并使用 20 多种最先进的算法训练了超过 10,000 个模型,我们揭示了对未来研究有启示的一些观察结果。
四、对于子群体偏移的细致分析
4.1 SOTA算法只在某些特定类型的偏移上改善了子群体的鲁棒性
首先,我们观察到目前最先进的算法只在某些类型的数据偏移上改善了子群体的鲁棒性,而在其他类型的数据偏移上并未改善。我们在这里绘制了各种最先进算法相对于 ERM 的最差子群体准确率的提升情况。对于 spurious correlation 和 class imbalance 而言,现有算法能够提供一致的相对于 ERM 的最差子群体增益,表明在解决这两种特定的数据偏移问题上已经取得了进展。
然而有趣的是,在 attribute imbalance 问题上,算法对于不同数据集的改善都很小。此外,在属性泛化 attribute generalization 方面,其性能甚至变得更差。这些发现强调了目前的进展仅针对特定类型的数据偏移,对于更具挑战性的属性泛化等偏移类型尚未取得进展。
4.2 Representation & classifier 在子群体偏移中的作用
此外,我们还探索了网络学习到的表示,即 representation,和分类器,即 classifier,在子群体偏移中的作用。具体来说,我们将整个网络分为两个部分:特征提取器 f 和分类器 g。其中,f 从输入中提取潜在特征,g 输出最终的预测结果。那么,表示和分类器如何影响子群体的性能?
首先,在基于 ERM 模型的基础上,当仅优化分类器而保持表示不变时,其可以显著提高 spurious correlation 和 class imbalance 情况下的性能,这表明 ERM 学习到的表示对于这两种偏移已经足够好了。
有趣的是,改进表示学习而不是分类器,则可以显著提高 attribute imbalance 问题的性能,这表明我们可能需要更强大的特征来应对某些特定的偏移。最后,没有任何学习方式能够在 attribute generalization 下带来性能提升。这凸显了在面对不同类型的偏移时,我们需要考虑模型不同组件设计的重要性。
4.3 模型选择和属性可用性对子群体偏移评估的影响
此外,我们观察到模型选择和属性可用性对子群体偏移评估有着相当大的影响。具体来说,当逐渐删除训练以及验证数据中的属性标注时,所有算法都经历了显著的性能下降,特别是当完全不知道训练和验证数据中的属性时。这表明,在子群体偏移中,对是否提前知道属性仍然是取得良好性能的条件,而未来的算法应该考虑更加真实的模型选择和属性可用性场景,例如完全不知道训练和验证数据中的属性时如何进行泛化。
4.4. 评估指标之间的非常根本的 tradeoff 关系
最后,我们揭示了评估指标之间的非常根本的 tradeoff 关系。最差子群体准确率,Worst-group accuracy,WGA,被认为是子群体评估的黄金标准。然而,提高 WGA 是否总能改善其他有意义的指标呢? 首先,我们展示了提高 WGA 可以导致某些指标的性能改善,比如这里展示的 adjusted accuracy,即平衡的准确率。
然而,如果我们进一步考虑最坏情况下的精确度,worst case precision,令人惊讶的是,它与 WGA 呈现出强烈的负线性相关性。这揭示了仅使用 WGA 来评估模型在子群体转变中的性能存在了根本的限制:一个在 WGA 上表现良好的模型可能具有很低的最坏情况下的精确度,而低精确度在关键应用领域如医学诊断中,是尤其令人担忧的。我们的观察强调了在子群体偏移中需要更加现实的评估指标。在论文中,我们还展示了许多其他与 WGA 呈负相关的指标。
五、结语
最后总结一下本文,针对子群体偏移 - subpopulation shift - 这个实际的问题,我们提出了一个全面的数学建模框架,一个涵盖了多种模态的 benchmark,并进行了细致的分析,得到了许多有趣的结果。这篇博客也仅仅是大体介绍了我们的研究,而具体细节还请大家直接读我们的文章。
当然,我们的工作还是存在其局限性,也留了一些坑。希望本文能抛砖引玉,也非常欢迎大家 follow 我们的工作!如果大家有任何想要交流的技术问题,欢迎留言多多交流。
参考文献
[1] Understanding the Failure Modes of Out-of-Distribution Generalization.
[2] Distributionally Robust Neural Networks for Group Shifts: On the Importance of Regularization for Worst-Case Generalization.
[3] Invariant feature learning for generalized long-tailed classification.
[4] The Caltech-UCSD Birds-200-2011 Dataset
[5] Large-scale CelebFaces Attributes (CelebA) Dataset
[6] MetaShift: A Dataset of Datasets for Evaluating Contextual Distribution Shifts and Training Conflicts
[7] NICO++: Towards Better Benchmarking for Domain Generalization
[8] Noise or Signal: The Role of Image Backgrounds in Object Recognition
[9] BREEDS: Benchmarks for Subpopulation Shift
[10] Nuanced Metrics for Measuring Unintended Bias with Real Data for Text Classification
[11] A Broad-Coverage Challenge Corpus for Sentence Understanding through Inference
[12] CheXpert: A Large Chest Radiograph Dataset with Uncertainty Labels and Expert Comparison
[13] MIMIC-CXR, a de-identified publicly available database of chest radiographs with free-text reports
文章来源:知乎
文章链接:
https://zhuanlan.zhihu.com/p/642511026
Illustration From IconScout By Delesign Graphics