本文介绍了一个面向可微近端算法建模的领域建模语言 ∇-Prox ,该语言为大规模优化问题提供了一种全新的解决方法。只需要数行代码,研究人员便可以通过 ∇-Prox ,以一种直观简洁的方式构造可微分的近端求解器,并基于此构造结合 Learning-based 和 Model-based 方法的混合求解器。得益于 ∇-Prox 的智能编译及优化,这些可微分求解器均可以通过包括算法迭代展开,均衡学习,以及强化学习的训练策略进行高效训练。实验证明,使用 ∇-Prox 构造的求解器在端到端计算光学,能源系统规划等领域均取得了显著改进。
论文链接:
https://light.princeton.edu/wp-content/uploads/2023/05/deltaprox-main.pdf
项目链接:
https://light.princeton.edu/publication/delta_prox/
代码链接:
https://github.com/princeton-computational-imaging/Delta-Prox
一、TL & DR
- ∇-Prox 是一个面向可微近端算法建模的编程“语言”(类似 PyTorch 是面向深度学习建模的编程”语言“)。
- ∇-Prox 非常易用,只需要 5-20 行 代码即可解决一个任务。
- ∇-Prox 支持通过强化学习,均衡学习,或迭代展开所实现的双层优化(仅需一行代码)。
- ∇-Prox 具有广泛的应用,这包括图像重建,医学影像,计算光学,能源系统规划等。
二、动机
跨越各种应用领域的任务均可视作大规模优化问题,这包括图形、视觉、机器学习、成像、医疗健康、调度、规划和能源系统预测等。无论应用领域如何,通过探索具体的问题结构,近端算法作为一个形式化的优化方法已经成功地解决了许多现有问题。
虽然这种方法提供了形式化的问题建模方法和收敛保证,但乍一看传统的近端优化方法似乎与近些年强大的深度学习方法格格不入。最近的研究表明,当与一些基于学习的模块结合时,基于近端算法的优化方法可以成为一种有效的且可解释的混合方法。
然而,手工地为不同的任务实现和测试这种混合方法需要同时具备近端优化和深度学习的专业知识,这往往容易出错且耗时。此外,为了训练可学习模块,简单地展开近端算法以产生可以微分的计算图,会导致内存消耗的爆炸式增长,从而使得基于批处理的训练变得非常具有挑战性。
为此,北京理工大学与普林斯顿大学、德国弗劳恩霍夫研究所联合提出了 ∇-Prox,一种用于可微近端算法建模的领域建模语言和编译器。∇-Prox 允许用户简洁地指定优化目标,并智能地编译一个计算与内存效率高的可微近端求解器。∇-Prox 支持构建与训练融合近端算法和可学习模块的混合求解器。∇-Prox 不仅支持包含微分优化的迭代展开训练策略,还支持了具有常数内存复杂度的均衡学习,以及可用于不可微分模块的强化学习策略。
仅需数行代码,∇-Prox 在一系列问题展示了其构造的高性能求解器,这包括端到端计算光学、压缩磁共振成像,以及综合能源系统规划等。
三、∇-Prox 初探
为了更好的介绍 ∇-Prox ,我们首先考虑一个简单的图像恢复逆问题,例如图像反卷积问题。在这个问题中,我们的目标是给定一个模糊图像,恢复出清晰的图像。这可以通过求解一个大规模优化问题进行解决:
其中平方项描述了清晰图像再经过一次卷积应当与观测保持一致 g(x) 则是一个隐式的先验项,它要求变量 x 需要“是一个图像”。
通过 ∇-Prox,我们可以以一种非常直观简洁的方式对其进行求解。
1.首先 import dprox 这个 Python 包,然后生成一组模拟数据,包括观测变量 y 和卷积的点扩散函数
from dprox import *
from dprox.utils import *
from dprox.contrib import *
img = sample()
psf = point_spread_function(15, 5)
b = blurring(img, psf)
2.接着,以和公式几乎完全对应的形式描述出这个优化问题
x = Variable()
data_term = sum_squares(conv(x, psf) - b)
reg_term = deep_prior(x, denoiser='ffdnet_color')
prob = Problem(data_term + reg_term)
3.一行代码指定算法进行求解
prob.solve(method='admm', x0=b)
进一步的,我们也可以通过 compile 这个函数得到一个 solver,并对 solver 进行特化,然后用于构造端到端的可学习求解器。
solver = compile(data_term + reg_term, method='admm')
rl_solver = specialize(solver, method='rl')
rl_solver = train(rl_solver, **training_kwargs)
四、原理概述
给定一个目标函数,∇-Prox 的求解过程主要分为两个阶段:
- compilation : 求解器编译
- specialization : 求解器特化
4.1 Compilation (编译)
编译过程是一个多阶段的过程。这个过程包括问题变换,问题划分,问题预处理,以及求解器生成,四个阶段。
- 问题变换阶段通过一个智能的翻译过程将原问题转化成一个能够被更高效求解的等价形式。
- 问题划分阶段负责将优化目标中的多个约束项划分成两组以满足指定算法的通用求解输入格式。
- 问题预处理会对问题进行预处理以达到加速算法收敛的目的。
- 求解器生成阶段负责具体的代码生成。
4.2 Specialization (特化)
特化阶段是 ∇-Prox 支持构建可微求解器的关键。该阶段通过多种不同的策略对求解器进行特化,使其能够高效地支持端到端的联合训练。例如,联合优化光学成像器件与后处理算法,训练适用于特定领域的先验网络,训练自动调参的策略网络等。
实现这一目的的核心则是三种不同的训练策略
- 迭代展开:展开近端算法的迭代过程对可学习变量进行训练。
- 均衡学习:具有常数内存复杂度的训练算法。
- 强化学习:利用强化学习解决不可微分路径的训练问题。
所有的这些策略都被集成进 ∇-Prox ,并且只需要一行代码即可进行调用
new_solver = specialize(solver, method='deq')
五、实验
我们在包括端到端计算光学,压缩核磁共振成像,图像去雨,综合能源系统规划等多个任务验证了∇-Prox的有效性。
六、欢迎体验
代码链接:
https://github.com/princeton-computational-imaging/Delta-Prox
参考
【1】∇-Prox: Differentiable Proximal Algorithm Modeling for Large-Scale Optimization Zeqiang Lai, Kaixuan Wei, Ying Fu, Philipp Härtel, and Felix Heide. ACM Transactions on Graphics, SIGGRAPH, 2023.
【2】ProxImaL: Efficient Image Optimization Using Proximal Algorithms F. Heide,S. Diamond,M. Niessner,J. Ragan-Kelley, W. Heidrich, and G. Wetzstein. ACM Transactions on Graphics, SIGGRAPH, 2016.
【3】TFPnP: Tuning-free Plug-and-Play Proximal Algorithm with Applications to Inverse Imaging Problems Kaixuan Wei, Angelica Aviles-Rivero, Jingwei Liang, Ying Fu, Hua Huang, Carola-Bibiane Schönlieb. Journal of Machine Learning Research, 2022.
【4】Deep Plug-and-Play Prior for Hyperspectral Image Restoration Zeqiang Lai, Kaixuan Wei, Ying Fu. Neurocomputing, 2022.
作者:赖泽强
来源:公众号【别有一洞天】
Illustration by IconScout Store from IconScout