TPAMI 2023 | 图神经网络在分布外图上的泛化

2023-12-08 16:41 513 阅读 ID:1689
将门
将门

图神经网络的分布外泛化能力决定了其在实际应用中的稳定性,是近年来的研究热点,该论文的初始版本于2021年11月放于arXiv(https://arxiv.org/abs/2111.10657),是早期将因果方法与图神经网络结合解决图分布外泛化问题的文章之一。本文将介绍图分布外泛化的关键问题、解决方法、以及未来研究工作。

论文标题: Generalizing Graph Neural Networks on Out-Of-Distribution Graphs 

论文链接: http://www.shichuan.org/doc/157.pdf 论文代码:https://github.com/googlebaba/StableGNN

一、引言

目前提出的图神经网络(GNN)方法没有考虑训练图和测试图之间的不可知偏差,从而导致GNN在分布外(OOD)图上的泛化性能变差。导致GNN方法泛化性能下降的根本原因是这些方法都是基于IID假设。在此条件下,GNN模型倾向于利用图数据中的虚假相关进行预测。但是,这样的虚假相关可能在未知的测试环境中改变,从而导致GNN的性能下降。因此,消除虚假相关的影响对于实现稳定的GNN模型至关重要。

为了实现此目的,在本文中,我们强调对于图级别任务虚假相关存在于子图级别单元,并且用因果视角来分析GNN模型性能下降的原因。基于因果视角的分析,我们提出了一个统一的因果表示框架用于稳定GNN模型,称之为StableGNN。这个框架的主要思想是首先利用一个可微分的图池化层提取图的高层语义特征,然后借助因果分析的区分能力来帮助模型摆脱虚假相关的影响。

因此,GNN模型可以更加专注于有区分性的子结构和标签之间的真实相关性。我们在具有不同偏差程度的仿真数据和8个真实的OOD图数据上验证了我们方法的有效性。此外,可解释性实验也验证了StableGNN可以利用因果结构做预测。

本质上说,对于一般的机器学习方法,当遭受分布偏移问题时,准确率下降的根本原因是不相关特征和类别标签之间的虚假相关导致的。这个虚假相关根本上是由不相关特征和相关特征的意外的相关性导致的。而对于本文研究的图级别任务,由于图的性质通常由子图单元决定 (比如,在分子图中,原子和化学键团表示其功能单元),所以我们定义一个子图单元可以是一个对于标签相关的或者不相关的特征单元。

如图1所示,以“房子”模体分类任务为例,其中图的标签表示一个图中是否有“房子”模体。GCN模型是在“房子”模体和“星星”模体高度相关的训练图上训练的。在这个数据上,“房子”模体和“星星”模体将会高度相关。这个意料之外的相关性将会导致“星星”模体的结构特征和“房子”标签的虚假相关。图1的第二列展示了用于GCN预测的最重要的子图可视化结果 (由GNNExplainer产生)。由结果可知,GNN倾向于利用星星模体做预测。然而当遭遇没有“星星”模体的图,或者其他模体(比如,"钻石"模体)和星星模体在一起时, GCN模型被证明容易产生错误的结果。

                                                                      图1

为了去除虚假相关对于GNN模型泛化性的影响,我们提出了一个新颖的用于图的因果表示框架,称之为StableGNN, 其结合了GNN模型灵活的表示学习和因果学习方法对于区分虚假相关能力的两方面优势。对于表示学习部分,我们提出了一个图高层语义学习模块,其利用了一个图池化层来映射相近的节点为簇,其中每一个簇为原始图中一个紧密连接的子图单元。

此外,我们理论证明了不同图的簇的语义含义可以通过一个有序的拼接操作实现匹配。给定了匹配的高层语义变量,我们用因果视角分析GNN的性能退化并且提出了一个新颖的因果变量区分正则化项通过学习一套样本权重来去除每一个高维变量对之间的相关性。这两个模块在我们的模型中联合训练。此外,如图1所示,StableGNN可以有效的排除不相关子图的影响(“星星”模体)并且利用真实的相关子图("房子"模体)做预测。

二、方法

所提出框架的基本想法是设计一个因果表示学习方法来抽取有意义的图高层语义变量然后估计他们对于图级别任务的真实因果效应。如图2所示,所提出的模型框架主要分为两个部分:图高层语义表示学习模块和因果变量区分模块。

                                                                  图2 StableGNN的模型框架

2.1 图高层变量学习

我们通过GNN层的置换不变性,证明了在经过DiffPool层之后,不同图之间对应的高层语义是匹配的:

2.2 因果变量区分正则化项

到目前为止变量学习部分学习的变量可能是虚假相关,在本节中,我们首先分析以因果视角分析导致GNN模型性能下降的原因,然后提出一个因果变量区分正则化器(CVD)。

以因果视角重视GNN

在这种情景下,目前的GNN方法不能准确的评估子图的因果效应,因此GNN的性能可能会衰减。混淆平衡技术通常被用来评估变量的因果效应,但是他们通常针对某一变量是由单个维度的特征组成的数据,我们要处理的数据是是多个高维变量组成的,因此,我们提出一种多变量多维度的混淆变量平衡技术,如图3b所示:

                                                                         图 3 GNN的因果视角

重加权HSIC

但是上述的混淆变量平衡技术主要针对的是二元处理变量,我们需要处理的高维处理变量。基于混淆平衡技术主要目的是去除处理变量和混淆变量之间的关联,我以我们考虑采用HSIC来度量高维变量之间的关联,同时提出采用样本加权的方式去除高维变量之间的关联,方法如下:对于两个变量U和V,我们首先采用随机初始化的样本权重来重加权它们:

然后我们可以得到加权的HSIC:

对于去除所有变量之间的相关性,我们优化如下的全局高维变量去相关项:

2.3 加权的GNN模型

三、实验

我们分别在仿真数据上和真实数据上验证了我们的实验效果。

3.1 仿真实验

我们通过控制“房子”模体和“星星”模体的相关性程度,生成了{0.6,0.7,0.8,0.9}四种偏差程度不同的训练数据,更多生成数据的细节请参考原文。我们分别以GCN/GraphSAGE作为基模型实现了我们的模型,所以本节主要和相应的基模型进行了对比。

实验结果如表1所示。首先,相较于基模型我们都取得了比较大的提升效果,证明了我们是个有效的框架。其次,在偏差程度越大的数据上提升效果越明显,证明了我们的方法可以有效对抗数据偏移产生的分布外效果下降的问题。最后,我们的模型相较于GCN/GraphSAGE都有明显的提升,证明了我们的方法是一个灵活的框架可以提升现有模型的效果。图4是一些可解释性的例子,也能很好的说明我们的模型可以利用因果结构进行预测。

                                                              图4 GCN和StableGCN的可解释性例子

3.2 真实数据实验

我们在OGB的七个分子图性质预测的数据上展开实验,与常用的数据不同的是,这些数据都采用scaffold splitting 从而使得具有不同结构的图数据划分到训练集和测试集。

此外,我们还采用了常用的MUTAG数据集用于解释我们的结果。表2是数据集的统计信息。表3是实验结果。从表中可以看出,我们的方法综合性能排在前两位,远远大于排名第3的方法,证明了现有GNN方法在OOD场景下的图预测任务上表现的都不好而我们的方法可以取得较好的结果。同时在不同类型数据,不同的任务的数据集上我们都取得了较好的效果,证明了我们的方法是一个通用的框架。

图5是MUTAG数据集上的可解释性实验。蓝色,绿色,红色和黄色分别代表N,H,O,C原子。由GNNExplainer产生的最重要的子图被标为黑色。StableGNN正确的确定了功能团NO2和NH2,这些功能团被认为是对Mutagenic 有决定性作用的,而其他方法不能找到有解释性的子图做预测。

                                                             图 5 MUTAG数据集上的可解释性实验

四、结论和未来工作

在本文中,我们首次研究了图数据在OOD上的泛化问题。我们以因果视角分析了这个问题,认为子图之间的虚假相关会影响模型的泛化性。为了提高现有模型的稳定性,我们提出一个一般化的因果表示学习框架,称之为StableGNN,其有效的结合图高层表示学习和因果效果评估到一个统一的框架里。丰富的实验很好的验证了StableGNN的有效性,灵活性,和可解释性。

此外,我们认为本文开启了一个在图数据上进行因果表示学习的方向。本文的最重要的贡献是提出了一个通用的因果表示框架:图高层变量表示学习和因果变量区分,这两个部分都可以为任务而特殊的设计。比如,多通道的滤波器可以被用来学习图上的不同的信号到子空间里。然后对于一些数据也许在高层变量之间存在这更复杂的因果结构,因此发现这些因果结构对于重构原始数据生成过程将会更有效。

引用

[1] Shaohua Fan, Xiao Wang, Chuan Shi, Peng Cui, Bai Wang. Generalizing Graph Neural Networks on Out-Of-Distribution Graphs. IEEE TPAMI 2023

[2]R. Ying, D. Bourgeois, J. You, M. Zitnik, and J. Leskovec, “Gnnex-plainer: Generating explanations for graph neural networks,” NeurIPS, 2019.

[3] X. Zhang, P. Cui, R. Xu, L. Zhou, Y. He, and Z. Shen, “Deep stablelearning for out-of-distribution generalization,” CVPR, 2021, pp.5372–5382

[4]R. Ying, J. You, C. Morris, X. Ren, W. L. Hamilton, and J. Leskovec,“Hierarchical graph representation learning with differentiablepooling,” NeurIPS, 2018.

[5] B. Schölkopf, F. Locatello, S. Bauer, N. R. Ke, N. Kalchbrenner,A. Goyal, and Y. Bengio, “Toward causal representation learning,”Proceedings of the IEEE, vol. 109, no. 5, pp. 612–634, 2021.

Illustration From IconScout By Manypixels Gallery

免责声明:作者保留权利,不代表本站立场。如想了解更多和作者有关的信息可以查看页面右侧作者信息卡片。
反馈
to-top--btn