来给大家介绍一下我们被接收为ICLR Spotlight的新工作。这个work从2021年春开始一直做到2022年秋,中间克服了许多技术障碍,没想到第一次投稿就好评如潮(分数8886),也恭喜子昊的坚持得到回报。这篇工作的核心贡献在于,正式定义了domain adaptation中的域索引(domain index),精心设计了推断(infer)domain index的算法(variational domain indexing, 即VDI),并且证明了我们的算法可以推断出最优的domain index。由于推断出来的domain index带来的free lunch,domain adaptation的性能也得到了提高。
论文链接:
http://wanghao.in/paper/ICLR23_VDI.pdf
代码链接:
https://github.com/Wang-ML-Lab/VDI
什么是domain index:domain index的说法最早在我们的ICML 2020论文“Continuously Indexed Domain Adaptation”(CIDA)中提出(有兴趣的看官欢迎移步我们讲CIDA的知乎帖子)。最直观的例子就是在医疗应用里面,不同年龄的人可以看成是不同的domain,而这个“年龄”其实就是domain的一个索引(index),也就是我们说的domain index(域索引)。如下图。
有意思的是,domain index其实是一个连续的概念,所以自然而然地包含了domain的远近信息。比如上面说的“年龄”可以作为一个一维的domain index,年龄18和19距离很近,而18和80却距离很远。我们之前在CIDA(大致的CIDA模型如下图)上的实验发现,如果已知这个domain index,我们可以很好地做到连续域上的domain adaptation,从而大幅提高准确率。
比如把模型从年龄0 ~ 20的病人(source domains),adapt到年龄20 ~ 80的病人(target domains),或者从年龄0 ~ 15以及50 ~ 80的病人(source domains),adapt到年龄15 ~ 50的病人(target domains),如下图。
那么问题来了,如果这个domain index是未知的,咋办? 最理想的情况当然是,我们能够把这个domain index作为隐变量(latent variable),通过无监督(unsupervised)的方式把它推断(infer)出来。如果这个方案可行,我们就免费拿到了一个重要的额外的信息,从而既可以提高domain adaptation的准确率,又能提高它的可解释性。
Domain Index的正式定义: 在推断domain index前,我们要先定义清楚,什么才算是domain index,然后才能设计推断它的方法。
这里我们首先引入了两种domain index,local domain index(用u表示)和global domain index(用β表示)。我们规定,虽然同一个domain里的不同数据点(data point)可以有不同的local domain index,但是同一个domain里的所有数据的global domain index必须是是相同的。也就是说,local domain index是一个instance-level的变量,而global domain index是一个domain-level的变量。下面的图是一个具体的例子,展示了global domain index β、local domain index u、数据x之间的关系。
如果 β 和 u 满足上述三个条件,我们就把它们分别称为global domain index和local domain index。这三个条件可以用下面的数学公式表示:
方法的整体思路:定义完domain index后,下一个问题自然就是,如何能在无监督(完全不知道domain index)的情况下,有效地推断出符合上面三段定义的domain index β 和 u 呢?这时,就要请出adversarial Bayesian deep learning model(对Bayesian deep learning感兴趣的同学可以看看我们之前的帖子)来解决这个问题。
在Bayesian deep learning里面,或者更加传统的probabilistic graphical model里面,我们会分两步走:
- 第一步是首先假设一下已知变量(observed variable)是如何从隐变量(latent variable,即未知的变量)一步步生成的。我们一般把这个叫做生成过程(generative process)。
- 然后第二步,就是通过贝叶斯推断(Bayesian inference)的方式来根据已知变量来倒推隐变量。
在我们目前的问题里,数据x以及标签y都是已知变量,而我们的encoding z以及domain index β和u则是隐变量。那么很自然,我们的目的就是已知各个domain里的数据x以及标签y,然后想推断出encoding z以及domain index β和u。注意,在domain adaptation里面,只有source domain才有已知的标签y。target domain只有数据x。
生成过程: 根据这个整体思路,我们就首先假设一下各个变量生成过程(如下图左边):
上面的目标函数可能有点冗长难懂,直接看下图可能会好些。直观地讲,我们可以把优化这个ELBO,看成学习很多子网络来对输入数据x进行编码 (encode)和重构(reconstruct)的过程,关键在于,在这个编码和重构的过程中,需要聪明地把domain index β 和 u 建模进去。
对贝叶斯推断(Bayesian Inference)熟悉的同学可能已经发现了,这个其实就是我们之前说的(广义的)贝叶斯深度学习(Bayesian Deep Learning)的思路:用深度模块(deep component)来处理高维信号x(比如图片),然后用概率图模块(graphical component)来表示各个随机变量之间的条件概率关系(比如图片x及其对应的encoding z和domain index β、u的关系)。
为了满足第一个要求,我们需要借鉴对抗域迁移(adversarial domain adaptation)的思想,在上图的基础上,再加上一个discriminator,然后对抗地(adversarially)训练整个网络,使得encoder能把不同domain的x映射到一个encoding空间,然后让这个discriminator无法从他们的encoding z来分辨出数据是来自于哪个domain的。我们把这个操作叫做encoding的对齐(alignment),即把不同的domain的encoding分布对齐起来,让他们互相重叠,这样就可以方便不同domain共享一个predictor了(比如分类器或者回归器)。加上discriminator之后的神经网络架构如下:。
最终的目标函数: 相应地,我们最终的目标函数也从一个简单的优化问题(最大化ELBO)变成了一个minimax game:
学到了啥有意思的domain index: 既然有了理论保障,那么接下来我们可以看一下,如果按照上面的方法训练模型,我们能推断出来什么样的global domain index呢?我们用的第一个数据集是之前CIDA用的Circle数据集。这个数据集包含了30个domain,如下图所示。左下图是用颜色标记了domain index,我们可以看到颜色是渐变的,也就是说ground-truth的domain index是从1到30。绿色框里表示的是6个source domain,其他部分为target domain。右下图是用蓝色和红色标记了标签(label),可以看出来这是个二分类的数据集,蓝色表示正例,红色表示负例。
下面的图展示了我们的VDI学习到的domain index 和ground-truth domain index的对比。可以看到,我们学到的domain index和真正的domain index是高度吻合的,correlation达到了0.97。有趣的是,跟CIDA不一样,我们在训练VDI过程中,并没有用到任何的domain index,所有的domain index都是VDI模型自己以无监督的方式推断出来的。
除了Circle这个toy dataset,我们还测试了现实的数据集。比如之前我们在GRDA构建的TPT-48温度预测数据集。这个数据集有美国大陆48个州的每月气温。这里的任务(task)是,根据前6个月的气温,预测后6个月的气温(如下图左边)。我们把一部分州的数据作为source domain(如下图黑底白字的州),然后把其他州作为target domain(如下图白底黑字的州)。我们把target domain分成3个层级,level-1、level-2、和level-3的target domain分别表示离source domain最近、次近、和最远的target domain。
有意思的是,即使在无监督(未知正确的domain index)的情况下,我们的VDI依然能够学出有意义的domain index。比如下图左边,我们画出来VDI学出来的2维的domain index β。下面每个点的坐标位置表示的是我们VDI学到的2维domain index,而颜色则表示对应的domain(州)真实的纬度。我们可以看到,我们domain index的第一维(横轴)和真实的每个州的纬度高度吻合。比如纽约(NY)和新泽西(NJ)纬度距离比较近,而且都在比较北边(如下面的右图),那么对应的,他们的domain index也很接近。相反,佛罗里达(FL)离NY和NJ的纬度距离都比较远,对应地,它的domain index也离NY和NJ比较远。
另一个真实数据集是CompCar,CompCar里包含了各种车的照片,这些照片有2维真实的domain index,拍照的角度(比如正面照、侧面照、后面照等等)以及出厂年份(比如2009)。类似地,我们把VDI学到的2维domain index画到下图。下面每个点的坐标位置表示的是我们VDI学到的domain index,而颜色则表示真实的拍照角度(左图)和出厂年份(右图)。可以看到,即使是在无监督的情况下,我们学出来的domain index依然和真实的拍照角度和出厂年份高度相关。
提高domain adaptation准确率: 当然除了能学出有意思的domain index,VDI自然可以利用这些学到的domain index,来提高domain adaptation的准确度。下面的表格是TPT-48的温度预测误差(MSE)对比。我们可以看到VDI几乎在所有层级(level)的target domain都能有准确率的提高。
写在最后
熟悉的同学可能可以看出来,这个VDI其实有点像是我们ICML’20的”Continuously Indexed Domain Adaptation”(CIDA)的逆问题,同时也可以看成是和CIDA这类算法的互补的问题。CIDA是想通过已知的domain index来提高连续域adaptation的准确度,而VDI则解决了一个更general的问题,也就是当这个domain index未知的时候,应该如何去推断出来。而且一旦推断出来domain index,我们就可以放心地继续使用CIDA来实现连续域(甚至是传统的离散域)的adaptation准确率的提升了。
Paper:https://arxiv.org/pdf/2302.02561.pdf or http://wanghao.in/paper/ICLR23_VDI.pdf
OpenReview:https://openreview.net/forum?id=pxStyaf2oJ5
YouTube Video:https://www.youtube.com/watch?v=xARD4VG19ec
Bilibili Video:https://www.bilibili.com/video/BV13N411w734/?share_source=copy_web
GitHub Link:https://github.com/wang-ML-Lab/VDI
作者:王灏
https://www.zhihu.com/question/557295083/answer/2977965268