ICLR 2023 Oral | DIFFormer:从扩散过程出发,实现物理启发的Transformer设计

2023-06-11 20:23 509 阅读 ID:1140
将门
将门

本文介绍一项近期的研究工作,试图建立能量约束扩散微分方程与神经网络架构的联系,从而原创性的提出了物理启发下的Transformer,称作DIFFormer。作为一种通用的可以灵活高效的学习样本间隐含依赖关系的编码器架构,DIFFormer在各类任务上都展现了强大潜力。这项工作已被International Conference on Learning Representations (ICLR2023)接收,并在首轮评审就收到了四位审稿人给出的10/8/8/6评分(最终均分排名位于前0.5%)。

论文地址:

https://arxiv.org/pdf/2301.09474.pdf

项目地址:

https://github.com/qitianwu/DIFFormer

一、简介

如何得到有效的样本表征是机器学习领域的一大核心基础问题,也是深度学习范式在各类下游任务能发挥作用的重要前提。传统的表征学习方法通常假设每个输入样本是独立的,即分别将每个样本输入进encoder网络得到其在隐空间中的表征,每个样本的前向计算过程互不干扰。然而这一假设通常与现实物理世界中数据的生成过程是违背的:由于显式的物理连接或隐含的交互关系,每个观测样本之间可能存在相互的依赖。

这一观察也启发了我们去重新思考用于表征计算的encoder网络设计:是否能设计一种新型的encoder网络能够在前向计算中显式的利用样本间的依赖关系(尽管这些依赖关系是未被观察到的)。 在这个工作中,我们从两个物理学原理出发,将神经网络计算样本表征的前向过程看作给定初始状态的扩散过程,且随着时间的推移(层数加深)系统的整体能量不断下降(见下图)。

DIFFormer模型主要思想的示意图:将模型计算样本表征的前向过程看作一个扩散过程,随着时间的推移,节点之间存在信号传递,且任意节点对之间信号传递的速率会随着时间适应性的变化,使得系统整体的能量最小化。通过扩散过程和能量约束,最终的样本表征能够吸收个体和全局的信息,更有助于下游任务

通过试图建立扩散微分方程与神经网络架构的联系,我们阐释了能量约束扩散过程与各类信息传递网络(如MLP/GNN/Transformers)的联系,并为新的信息传递设计提供了一种理论参考。基于此,我们提出了一种新型的可扩展Transformer模型,称为DIFFormer(diffusion-based Transformers)。它可以作为一种通用的encoder,在前向计算中利用样本间隐含的依赖关系。大量实验表明在小/大图节点分类、图片/文本分类、时空预测等多个领域的实验任务上DIFFormer都展现了强大的应用潜力。在计算效率上,DIFFormer只需要3GB显存就可以实现十万级样本间全联接的信息传递。

二 、动机与背景

我们首先回顾一个经典的热力学中的热传导过程:假设系统中有N个节点,每个节点有初始的温度,两两节点之间都存在信号流动,随着时间的推移节点的温度会不断更新。上述物理过程事实上可以类比的看作深度神经网络计算样本表征(embedding)的前向过程。

将神经网络的前向计算过程看作一个扩散过程:每个样本视为流形上的固定位置节点,样本的表征为节点的信号,表征的更新视作节点信号的改变,样本间的信息传递看作节点之间的信号流动

这里我们可以把每个样本看作一个离散空间中的节点,样本的表征看作节点的信号。当模型结构考虑样本交互时(如信息传递),它可以被看作节点之间的信号流动,随着模型层数加深(即时间的推移),样本表征会不断被更新。

三 、扩散过程的描述

一个经典的扩散过程可以由一个热传导方程(带初始条件的偏微分方程)来描述

在离散空间中,梯度算子可以看作两两节点的信号差异,散度算子可以看作单个节点流出信号的总和,而扩散率(diffusivity)是一种对任意两两节点间信号流动速率的度量

由此我们可以写出描述N个节点每时每刻状态更新的扩散微分方程,它描述了每个状态下系统中每个节点信号的变化等于流向其他节点的信号总和:

由扩散方程导出的信息传递 我们进一步使用数值有限差分(具体的这里使用显式欧拉法)将上述的微分方程展开成迭代更新的形式,引入一个步长τ对连续时间进行离散化(再经过方程左右重新整理):

下图概述了这三种信息传递模式:

我们研究最后一种信息传递方式,每层更新的样本表征会利用上一层所有其他样本的表征,在理论上模型的表达能力是最强的。但由此产生的一个问题是:要如何才能确定合适的每层任意两两节点之间的diffusivity,使得模型能够产生理想的样本表征?

基于这一理论结果,我们进而提出了扩散过程诱导下的Transformer结构,即DIFFormer,它的每一层更新公式表示为:

两种模型DIFFormer-s和DIFFormer-a每层更新的运算过程(矩阵形式),红色标注的矩阵乘法操作是计算瓶颈。DIFFormer-s的优势在于可以实现对样本数量N的线性复杂度,有利于模型扩展到大规模数据集

模型扩展 更进一步的,我们可以引入更多设计来提升模型的适用性和灵活度。上述的模型主要考虑了样本间的all-pair attention。对于输入数据本身就含有样本间图结构的情况,我们可以加入现有图神经网络(GNN)中常用的传播矩阵(propagation matrix)来融合已知的图结构信息,从而定义每层的样本表征更新如下

  类似其他Transformer一样,在每层更新中我们可以加入residual link,layer normalization,以及非线性激活。下图展示了DIFFormer的单层更新过程。

DIFFormer的全局输入包含样本输入特征X以及可能存在的图结构A(可以省略),通过堆叠DIFFormer layer更新计算样本表征。在每层更新时,需要计算一个全局attention(具体的可以使用DIFFormer-s和DIFFormer-a两种实现),如果考虑输入图结构则加入GCN Conv

另一个值得探讨的问题,是如何处理大规模数据集(尤其是包含大量样本的数据集,此时考虑全局all-pair attention非常耗费资源)。在这种情况下我们默认使用线性复杂度的DIFFormer-s的架构,并且可以在每个训练epoch对数据集进行random mini-batch划分。由于线性复杂度,我们可以使用较大的batch size也能使得模型在单卡上进行训练(详见实验部分)。

对于包含大量样本的数据集,我们可以对样本进行随机mini-batch划分,每次只输入一个batch的样本。当输入包含图结构时,我们可以只提取batch内部样本所组成的子图输入进网络。由于DIFFormer-s只需要对batch size的线性复杂度,在实际中就可以使用较大的batch size,保证充足的全局信息

四、实验结果

为了验证DIFFormer的有效性和在不同场景下的适用性,我们考虑了多个实验场景,包括不同规模图上的节点分类、半监督图片/文本分类和时空预测任务。

图节点分类实验 此时输入数据是一张图,图中的每个节点是一个样本(包含特征和标签),目标是利用节点特征和图结构来预测节点的标签。我们首先考虑小规模图的实验,此时可以将一整图输入DIFFormer。相比于同类模型例如GNN,DIFFormer的优势在于可以不受限于输入图,学习未被观测到的连边关系,从而更好的捕捉长距离依赖和潜在关系。下图展示了与SOTA方法的对比结果。

进一步的我们考虑在大规模图上的实验。此时由于图的规模过大,无法将一整图直接输入模型(否则将造成GPU过载),我们使用mini-batch训练。具体的,在每个epoch,随机的将所有节点分为相同大小的mini-batch。每次只将一个mini-batch的节点输入进网络;而对于输入图,只使用包含在这个mini-batch内部的节点所组成的子图输入进网络;每次迭代过程中,DIFFormer也只会在mini-batch内部的节点之间学习all-pair attention。这样做就能大大减小空间消耗。又因为DIFFormer-s的计算复杂度关于batch size是线性的,这就允许我们使用很大的batch size进行训练。

下图显示了在ogbn-proteins和pokec两个大图数据集上的测试性能,其中对于proteins/pokec我们分别使用了10K/100K的batch size。特别的,对于包含百万级节点的数据集pokec,DIFFormer只消耗了3GB的GPU显存。此外,下图的表格也展示了batch size对模型性能的影响,可以看到,当使用较大batch size时,模型性能是非常稳定的。

图片/文本分类实验 第二个场景我们考虑一般的分类问题,输入是一些独立的样本(如图片、文本),样本间没有已观测到的依赖关系。此时尽管没有输入图结构,DIFFormer仍然可以学习隐含在数据中的样本依赖关系。对于对比方法GCN/GAT,由于依赖于输入图,我们这里使用K近邻人工构造一个样本间的图结构。

时空预测 进一步的我们考虑时空预测任务,此时模型需要根据历史的观测图片段(包含上一时刻节点标签和图结构)来预测下一时刻的节点标签。这里我们横向对比了DIFFormer-s/DIFFormer-a在使用输入图和不使用输入图(w/o g)时的性能,发现在不少情况下不使用输入图模型反而能给出的较高预测精度。这也说明了在这类任务中,给定的观测图结构可能是不可靠的,而DIFFormer则可以通过从数据中学习依赖关系得到更有用的结构信息。

五 、扩散过程下的统一视角

从能量约束的扩散过程出发,我们也可以将其他信息传递模型如MLP/GCN/GAT看作DIFFormer的特殊形式,从而给出统一的形式化定义。下图概括了几种方法对应的能量函数和扩散率。相比之下,从扩散过程来看,DIFFormer会考虑任意两两节点之间的信号流动且流动的速率会随着时间适应性的变化,而GNN则是将信号流动限制在一部分节点对之间。从能量约束来看,DIFFormer会同时考虑局部(与自身状态)和全局(与其他节点)的一致性约束,而MLP/GNN则是分别侧重于二者之一,且GNN通常只考虑输入图中相邻的节点对约束。

六 、总结与讨论

在这个工作中,我们讨论了如何从扩散方程出发得到MLP/GNN/Transformer的模型更新公式,而后提出了一个能量约束下的扩散过程,并通过理论分析得到了最优扩散率的闭式解。基于理论结果,我们提出了DIFFormer。总的来说,DIFFormer主要具有以下两点优势:

  • 从设计思想上看: 模型结构从能量下降扩散过程的角度导出,相比于直接的启发式设计更加具有理论依据;
  • 从模型实现上看:在保留了学习每层所有节点全局all-pair attention的表达能力的同时,DIFFormer-s只需要O(N)复杂度来更新N个节点的表征,同时兼容mini-batch training,可以有效扩展到大规模数据集

DIFFormer作为一个通用的encoder,可以被主要应用于以下几种场景:

最后欢迎感兴趣的朋友们阅读论文和访问我们的Github,共同学习进步~

参考文献

[1] Qitian Wu et al., DIFFormer: Scalable (Graph) Transformers Induced by Energy Constrained Diffusion, ICLR 2023.

[2] Qitian Wu et al., NodeFormer: A Scalable Graph Structure Learning Transformer for Node Classification, NeurIPS 2022.

[3] Chenxiao Yang et al., Geometric Knowledge Distillation: Topology Compression for Graph Neural Networks, NeurIPS 2022.

作者:吴齐天

本文来自:https://zhuanlan.zhihu.com/p/622970740

免责声明:作者保留权利,不代表本站立场。如想了解更多和作者有关的信息可以查看页面右侧作者信息卡片。
反馈
to-top--btn