本文旨在深入浅出的介绍图上的分布外泛化问题(一个最近刚火的研究方向)与基于(因果)不变性原理的求解思路,对相关领域研究者提供easy-to-follow的讲解。本文内容主要基于今年年初笔者发表于ICLR‘22的论文《Handling Distribution Shifts on Graphs: An Invariance Perspective》。
这项工作首次对图上的节点级任务的分布外泛化问题给出了一般化定义,并基于不变性原理给出了有理论保障的解决思路。文末还会简单介绍笔者合作参与的三个刚被NeurIPS‘22接收的相关工作,并讨论可以进一步探索的方向。
论文地址:
https://arxiv.org/abs/2202.02466
项目地址:
https://github.com/qitianwu/GraphOOD-EERM
论文解读-知乎:
https://zhuanlan.zhihu.com/p/580112987
图机器学习目前依然是炙手可热的研究领域,但不少的已有方向都遇到了瓶颈期。本文将要重点介绍的分布外泛化问题(Out-of-distribution Generalization,简称OOD泛化)也为图学习引入了一个新的子赛道,与现有的很多场景和设定都存在可能的交叉,目前有很大的研究空间。
一、为什么要考虑分布外泛化的问题?
如何提高在新数据(例如未知分布或未见实体)上的泛化性能是机器学习的一个核心问题。我们知道一般的学习问题都是在一个训练集上完成模型训练,而后模型需要在一个新的测试集上给出结果。机器学习问题的误差可以被大致分解为两部分:
其中表征误差(反映了模型拟合训练数据的能力)是由模型的表达能力/容量决定的,而泛化误差则由在训练集与测试集模型表现的差异决定。当我们采用较为复杂的模型结构(例如神经网络)与有效的优化算法,可以大大降低表征误差。但是当测试数据分布与训练分布呈现明显不同时,模型的泛化误差则很难被控制。这样的场景在实际中也很常见,比如在线下数据进行训练的推荐模型需要泛化到线上的真实场景,在模拟场景下训练的驾驶器要泛化到具有真实交互的环境中。这就是分布外泛化要解决的核心问题:如何利用有限观测的数据,学习一个稳健的模型,能够泛化到与训练分布有明显差异的测试数据上。
二、图上的节点级分布外泛化的挑战
目前大部分关于分布外泛化问题的研究集中在欧式数据(如图片),而对于图结构数据的相关研究还较少。与普通欧式数据不同的是,图结构数据上节点级预测任务的分布偏移问题需要解决两个核心的技术挑战。
- 样本互连性: 由于节点的互连特性,数据样本通常是非独立同分布的,这就为数据分布的建模带来了困难。下图给出了一个简单示意,对于图片数据我们可以把生成每张图片的分布看作相同且独立的;然而对于图结构数据,每个节点的生成依赖于邻居节点,数据分布不能被看作独立的。
- 图结构信息: 除了节点特征外,图的结构也蕴含了重要的信息,会影响到表示学习和预测任务。因此,在考虑数据分布建模与模型泛化的时候,也需要挖掘结构信息的特征并兼顾其影响。
三、从问题定义出发
然而上述的定义方式不方便对每个节点进行分析,因此下面我们考虑一种以节点为单位的定义。
四、基于不变性原理的分布外泛化
直接解决上述的问题是非常困难的,因为模型在没有结构性假设和对学习任务的先验知识的情况下往往是不可能实现分布外泛化的(没有免费的午餐)。为此,本文从数据生成的角度,通过利用数据背后的因果不变性[1,2,3],来引导模型学习到可以实现泛化的映射关系。
受以上思路的启发,我们把学习目标定义为在不同环境上对应风险损失的均值和方差:
这里定义
β是一个权重超参数。这一目标的直观考虑是如果模型在不同的环境下能够给出相近的结果(即loss方差最小化),其学到的从x到y的映射就是相对环境不变的。这也有别于传统的监督学习方法Empirical Risk Minimization(ERM),即只对每个样本的loss的均值进行优化,这种情况下模型就很容易学到与环境相关的映射,在训练数据上发生过拟合(对于ERM的局限性分析感兴趣的读者,可以进一步参考我们的论文)。
然而,上式则要求训练数据中包含来自多个环境的观测数据,并且每个数据样本对应的环境id也是已知的。对于图结构数据,尤其是节点级任务,这两个要求都是不满足的。通常情况下,训练数据只包含了一整张大图,也没有足够的每个节点对应哪个环境的信息。为了解决这一困难,我们引入 K 个额外的数据生成器
基于输入图生成 K 份不同的图数据
来探索环境,模拟来自不同环境的观测数据。基于此,我们考虑如下的双层优化学习目标:
这里我们定义每个图数据所对应的损失函数
针对数据生成器
我们将其参数化为一个图结构编辑器(graph editor),即将每一条连边假设为自由参数,对输入图进行局部改变(删除或增加连边)。具体的,我们将每一个改变视为动作(action),最终使用基于策略梯度的REINFORCE算法进行优化,以解决离散动作空间采样不可导的问题。我们将本文提出的方法称为Explore-to-Extrapolation Risk Minimization(EERM),下图给出了训练过程的数据流图。
理论分析
为了证明提出方法的有效性,该工作给出了几点理论分析。这里将主要结论整理如下(对此部分感兴趣的读者可以阅读论文):
- 提出的方法EERM可以引导模型产生的预测分布学习到稳定的从输入特征到标签的映射关系,从而在理论上保证取得理想的分布外泛化问题的最优解(由(1)式给出)。
- 当模型给出的节点表示在训练集和测试集上具有相同的表达能力(具体量化为输入与输出包含在表示向量中的信息),本文提出的EERM可以降低测试分布上的泛化误差上界。
实验结果
为了进一步验证提出的方法,我们需要设计实验,测试模型在不同数据分布上的性能。真实的图数据中可能包含多种不同的分布偏移,这里我们考虑三种情况:人造混淆噪声(Artificial Transformation)、跨图领域迁移(Cross-Domain Transfer)、动态图时序泛化(Temporal Evolution)。下表展示了本文使用的6个数据集以及对应的分布偏移的形式。
处理人造混淆噪声 我们首先考虑Cora和Amazon-Photo数据集,对其引入噪声,方法如下:采用两个随机初始化的GCN,第一个GCN基于原始节点特征生成节点真实标签,第二个GCN基于节点标签和环境id生成冗余特征,于是节点的特征为原始特征和冗余特征的拼接。对每个数据集,我们将环境id设为1-10,总共生成10张图,第一张用于训练,第二张验证,其余的作为测试。如此下来,训练集与测试集之间就被引入了分布偏移,原始特征与标签的关系是对于环境不变的,而冗余特征与标签的关系则是环境敏感的。我们考虑使用GCN作为预测模型主干,下图分别显示了使用传统方法(Empirical Risk Minimization,ERM,即直接优化训练数据的损失)与本文提出方法(EERM)在Cora和Amazon数据集上8个测试图的准确率(Accuracy)对比。这里,我们重复了20次实验(使用不同网络初始化),展示了准确率的分布情况。可以看到,EERM在绝大多数情况下好于ERM。
跨图领域泛化 一种典型的分布外泛化场景是图数据上的领域泛化(Domain Generalization)。这里我们考虑Twitch-Explicit和Facebook-100数据集,它们都是社交网络,分别包含了7张和100张子图。我们使用一部分图作为训练集,另一部分作为测试。由于每一张子图都是来自不同地区的社交网络,而且大小、密度、标签分布都不尽相同,因此训练数据与测试数据就天然存在分布偏移。
对于Twitch数据集,我们使用子图DE作为训练集,ENGB作为验证集,其余作为测试集。由于是二分类问题且类别标签不均衡,所以我们使用ROC-AUC作为评测指标。下图显示了分别使用GCN、GAT、GCNII作为网络主干,ERM与EERM在5个测试图上的性能对比。可以看到,EERM在大部分情况下都超越了ERM。
对于Facebook数据集,我们考虑使用多个图进行训练。具体的,我们考虑三种训练子图的组合。下表显示了使用不同训练子图的组合,在三个测试图(Penn,Brown,Texas)上的准确率对比。同样,EERM在绝大部分情况下超越了ERM。
动态图时序外推 另一种典型的分布偏移来源于时序动态图,训练数据往往是历史某个阶段收集的片段,测试数据则来源于未来。随着时间的推移,图数据可能发生变化。这里我们进一步考虑两种不同的情况。第一种情况对应动态的时序snapshot,我们考虑Elliptic数据集,它一共包含49个graph snapshot,每一个记录了在一段时间内的金融交易,任务是识别网络中的非法节点。我们把snapshot按时间顺序排列,使用前5个作为训练,第6-10个作验证,其余的作为测试集(把每相邻的4个合并为一组)。我们使用F1分数作为评测指标,下图显示了使用GraphSAGE和GPRGNN作为主干模型的效果对比。可以看到,EERM显著好于ERM,取得了平均9.6%/10.0%的提升。
接着我们考虑第二种情况,随着时间的推移,图中的节点和连边会发生变化。这里我们考虑OGBN-Arxiv数据集,其中每个节点是论文。我们按论文的发表时间将节点分为训练集和测试集。为了引入分布偏移,我们扩大训练节点和测试节点的时间间隔:使用2011前发表的论文作为训练集,2011-2014年发表的论文作为验证集,2014年之后的为测试集。下表展示了时间在2014-2016/2016-2018/2018-2020年的测试节点上的测试准确率。可以看到,随着时间的推移(分布偏移进一步扩大),模型的性能都呈现下降趋势,但ERM的下降趋势更为明显。这也说明,EERM能够有效提升模型对分布偏移的鲁棒性。
五、讨论与展望
图级别预测与节点级预测的联系与区别 近期也有不少工作关注神经网络在图结构数据上的分布外泛化/外推问题,例如[4, 5, 6]。然而,他们主要专注于整图级别(graph-level)任务,有别于本文主要关注的节点级(node-level)任务。整图级别任务与节点级任务所关注的重点与技术难点是不同的。
对于图级别任务的分布外泛化问题,可以采用如下定义(将式(1)修改为)
这里
表示一个图样本(如分子图),分类器f以每张图作为输入预测它的标签(例如分子的性质)。希望进一步了解如何利用不变性原理求解整图级别任务(分子图预测)下的分布外泛化问题,可以阅读笔者参与的一篇刚被NeurIPS22接收的论文[7],相关方法也在OGB和DrugOOD标准benchmark上取得了SOTA效果。
更深层的图数据生成过程 图数据本身包含的结构拓扑信息是其不同于一般欧式数据的特性之一,因此在考虑图上的分布偏移时也需要对观测数据背后隐含拓扑进行考虑和建模。比如一个常见的场景是训练和测试在不同的图数据上,模型训练的图是一个完整的观测图,测试时的图拓扑发生了改变(例如节点、连边变化)。笔者参与的另一个NeurIPS22论文[8]就对此类图拓扑发生偏移的情形下如何提升模型泛化能力进行了探索,主要思路是从热传导过程与图神经网络的等价关系出发,挖掘拓扑背后的几何特性,引导模型学习对图的观测结构变化保持不变的映射关系。
分布外数据的判别 另一个与本文高度相关的问题是如何对分布外数据进行识别或检测。在本文所讨论的问题设定下,分布外数据只出现在了测试阶段。而现实中分布外数据也可能存在于训练集中,一个需要解决的问题就是如何识别与训练主体数据(分布内数据)有明显差异的分布外数据,帮助提升模型可靠性。针对这种情形,另一个NeurIPS22的工作[9]从数据生成过程出发提出了一个统一框架处理两个问题:
1. 如何识别训练集中的分布外数据;
2. 如何判别测试阶段模型未见的分布外数据。
OOD判别与OOD泛化本身存在很多交集,也期待后续更多的工作对其进行补充和探索。
参考文献
[1] Mateo Rojas-Carulla, et al. Invariant models for causal transfer learning. In Journal of Machine Learning Research (JMLR), 2018.
[2] Martín Arjovsky, et al. Invariant risk minimization. CoRR, abs/1907.02893, 2019.
[3] Peter Bühlmann. Invariance, causality and robustness. CoRR, abs/1812.08233, 2018.
[4] Keyulu Xu, et al. How neural networks extrapolate: From feedforward to graph neural networks. In International Conference on Learning Representations (ICLR), 2021.
[5] Beatrice Bevilacqua, et al. Size-invariant graph representations for graph classification extrapolations. In International Conference on Machine Learning (ICML), 2021.
[6] Haoyang Li et al. OOD-GNN: Out-of-Distribution Generalized Graph Neural Network. In Transactions on Knowledge and Data Engineering (TKDE), 2022.
[7] Nianzu Yang, et al, Learning Substructure Invariance for Out-of-Distribution Molecular Representations. In Advances in Neural Information Processing Systems (NeurIPS), 2022.
[8] Chenxiao Yang, et al. Geometric Knowledge Distillation: Topology Compression for Graph Neural Networks. In Advances in Neural Information Processing Systems (NeurIPS), 2022.
[9] Zenan Li, et al. GraphDE: A Generative Framework for Debiased Learning and Out-of-Distribution Detection on Graphs. In Advances in Neural Information Processing Systems (NeurIPS), 2022.
作者:吴齐天
文章来源:知乎文章【https://zhuanlan.zhihu.com/p/580112987】
Illustration by IconScout Store from IconScout