切换JAX,强化学习速度提升4000倍!牛津大学开源框架PureJaxRL,训练只需GPU

2023-05-11 12:40 434 阅读 ID:1048
新智元
新智元

还在为强化学习运行效率发愁?无法解释强化学习智能体的行为?

最近来自牛津大学Foerster Lab for AI Research(FLAIR)的研究人员分享了一篇博客,介绍了如何使用JAX框架仅利用GPU来高效运行强化学习算法,实现了超过4000倍的加速;并利用超高的性能,实现元进化发现算法,更好地理解强化学习算法。

文章链接:https://chrislu.page/blog/meta-disco/

代码链接:https://github.com/luchris429/purejaxrl

作者团队开发的框架PureJaxRL可以极大降低进入Deep RL研究的算力需求,使学术实验室能够使用数万亿帧进行研究(缩小了与工业研究实验室的差距),独立研究人员也可以利用单个GPU进行开发。

文章作者Chris Lu是牛津大学博士二年级学生,工作重点是将进化启发(evolution-inspired)的技术应用于元学习和多智能体强化学习,曾在DeepMind实习。

使用PureJaxRL实现超过4000倍加速

GPU is all you need

  大多数Deep RL的算法同时需要CPU和GPU的计算资源,通常来说,环境(environment)在CPU上运行,策略神经网络运行在GPU上,为了提高wall clock速度,开发者往往使用多个线程并行运行多个环境。


但如果是用JAX的话,可以直接将环境向量化(vectorise),并将其在GPU上运行,而无需使用CPU的多线程。


不仅可以避免在CPU和GPU之间传输数据以节省时间,如果使用JAX原语来编写环境程序,还可以使用JAX强大的vmap函数来立即创建环境的矢量化版本。


虽然在JAX中重写RL环境可能很费时间,但幸运的是,目前已经有一些库提供了各种环境:


Gymnax库包括了多个常用的环境,包括经典的控制任务,Bsuite任务和Minatar(类似Atari的)环境。

                                               链接:https://github.com/RobertTLange/gymnax

研究人员选择Gymnax作为测试和评估代码的首选库,在这篇文章中的示例用的也是Gymnax,库中还包括许多其他非常简洁的功能,并且非常容易使用。

Brax是使用JAX运行类似Mujoco的连续控制环境的方法,该库包含许多强化学习环境,可以对标类似的连续控制环境,如HalfCheetah和Humanoid,并且也是可微的!

                                                    链接:https://github.com/google/brax

Jumanji包含许多令人特别炫酷、简单和行业驱动的环境,库中的许多环境都直接来自于行业设置,确保这里提供的环境是实用的并且与现实世界相关,具体问题包括组合问题,如著名的旅行推销员问题或3D装箱。

                                                链接:https://github.com/instadeepai/jumanji

Pgx有许多流行的桌面游戏和其他环境,包括Connect 4、围棋、扑克!

                                                    链接:https://github.com/sotetsuk/pgx

在Gymnax的测速基线报告显示,如果用numpy使用CartPole-v1在10个环境并行运行的情况下,需要46秒才能达到100万帧;在A100上使用Gymnax,在2k 环境下并行运行只需要0.05秒,加速达到1000倍!

这个结论也适用于比Cartpole-v1更复杂的环境,例如Minatar-Breakout需要50秒才能在 CPU 上达到100万帧,而在 Gymnax 只需要0.2秒。

这些实验结果显示了多个数量级的改进,使学术研究人员能够在有限的硬件上高效地运行超过数万亿帧的实验。

在JAX中端到端地进行所有操作有几个优势:

  • 在加速器上的矢量化环境运行速度更快。
  • 通过将计算完全保留在GPU上,可以避免在CPU和GPU之间来回复制数据的开销,通常也是性能的一个关键瓶颈。
  • 通过JIT编译实现,可以避免Python的开销,有时会阻塞发送命令之间的GPU 计算。
  • JIT 编译通过运算符融合(operator fusion)可以获得显著的加速效果,即优化了GPU上的内存使用。
  • 多线程的并行运行环境很难调试,并且会导致复杂的基础设施。

为了证明这些优势,作者在纯JAX环境中复制了CleanRL的PyTorch PPO基线实现,使用了相同数量的并行环境和相同的超参数设置,并且没有利用海量环境矢量化的优势。

在Cartpole-v1和MinAtar-Breakout中运行5次,训练过程如下。

Cartpole-v1和 MinAtar-Breakout上的CleanRLvs JAX PPO,给定相同的超参数和帧数,得到了几乎相同的结果。

将x轴从帧替换为wall-clock time(某个线程上实际执行的时间)后,在没有任何额外并行环境的情况下,速度提升了10倍以上。

      Cartpole-v1和 MinAtar-Breakout上的CleanRL vs JAX PPO,得到了相同的结果,但是快了10倍以上!

并行运行多个智能体

虽然可以从上述技巧中得到相当不错的加速效果,但与标题中的4000倍加速仍然相去甚远。

通过向量化整个强化学习训练循环以及之前提到JAX中的vmap,可以很容易地并行训练多个智能体。


rng = jax.random.PRNGKey(42)
rngs = jax.random.split(rng, 256)
train_vjit = jax.jit(jax.vmap(make_train(config)))
outs = train_vjit(rngs)

此外,还可以使用JAX提供的pmap函数在多个 GPU 上运行,在此之前,这种跨设备的并行化和向量化,尤其是在设备内部的并行化和向量化,是一个非常令人头疼的问题。

Cartpole-v1和 MinAtar-Breakout 上的CleanRL vs Jax PPO,可以将智能体训练本身并行化。在 Cartpole-v1上,只需要用训练一个CleanRL智能体的一半时间来训练2048个智能体。

如果正在开发一个新的强化学习算法,那么就可以在单个GPU上同时对具有统计学意义的大量种子进行快速训练。

除此之外,还可以同时训练成千上万的独立智能体,在作者提供的代码中,还展示了如何使用进行快速超参数搜索,也可以将其用于进化元学习!

Deep RL的元进化发现

元学习,或者说「学会学习」,通过发现可以应用于广泛任务的一般原则和算法,有潜力彻底改变强化学习领域。

在FLAIR时,作者使用上述计算技术通过进化(evolution)为Meta-RL的新发现提供基础,并有望提高对强化学习算法和智能体的理解,这些优势非常值得探索。

传统的元学习技术,通常使用元梯度或高阶导数,注重只使用少量的样本来快速适应相似但没见过的任务。

虽然这种方法在特定领域内运行良好,但它不能实现通用的学习算法,虽然这种算法可以在更新中处理不同的任务,当试图跨越数百万个时间步和数千个更新进行元学习时,这种局限性变得更加明显,因为基于梯度的方法通常会导致高方差更新,从而影响性能。

另一方面,进化方法提供了一个有前景的选择方案,通过把潜在的问题作为一个黑盒子来处理,并避免显式地计算导数,可以有效地跨越long horizons进行元学习。

进化策略(evolutionary strategies, ES)的主要优势包括:

  • 学习时间步的不可知论(Agnosticism)
  • 不用担心梯度消失或梯度爆炸
  • 无偏见的更新
  • 通常较低的方差
  • 高可并行性

在更高的层面上,这种方法反映了自然界中学习的出现,在自然界中,动物已经进化出在大脑中执行强化学习的基因。

对进化方法的主要批评是,它们可能速度比较慢,并且样本效率较低,往往需要同时评估数千个参数,而这个框架可以通过在有限的硬件上实现快速并行评估来解决这些问题,使元强化学习中的进化成为一个有吸引力的和实用的选择。

一个比较方便的库是evosax(由Gymnax开发者打造),可以很容易地将强化学习训练循环连接到这个库,并完全在GPU上执行极其快速的元进化。

比如说,通过元学习获得Cartpole-v1上 PPO智能体的价值损失函数;在外部循环中,采样这个神经网络的参数(元参数) ,在内部循环中,从头开始训练强化学习智能体,并使用这些元参数对值损失函数进行训练。

在一个Nvidia A40 GPU上,通过超过1000亿帧训练了1024 generations,得到512个智能体,也就是说在9个小时内就在一个 GPU 上训练了超过50万的智能体!

                                            元学习价值距离函数,得到的学习距离函数优于L2
                                                                          元学习值距离函数

实验结果看起来很有趣,看起来一点也不像标准的L2损失,并不对称,甚至不凸,总之,元进化发现框架包括:

  • 使用 Jax 在 GPU 上运行一切
  • 使用进化方法跨整个训练轨迹的元学习。
  • 解释学习的元参数来「发现」关于学习算法的新见解

案例研究

这是一个非常强大的框架,作者团队在FLAIR发表的多篇论文中都用到它来更好地理解强化学习算法的行为。

发现策略优化(Discovered Policy Optimisation)

在过去的十年里,强化学习已经取得了巨大的进步,这些进步中的大部分来自于新算法的不断发展,这些算法是使用数学推导、直觉和实验相结合的方法设计的,这种手工创建算法的方法受到人类理解力和独创性的限制。

相比之下,元学习为自动机器学习方法优化提供了一个工具包,可能会解决这个缺陷。

然而,试图以最小先验结构发现强化学习算法的黑盒方法迄今为止还没有优于现有的手工算法。

                                               论文链接:https://arxiv.org/pdf/2210.05639.pdf

镜像学习(Mirror Learning),可能是一个潜在的中间起点:虽然这个框架中的每一个方法都有理论上的保证,但是区分它们的组件受制于设计。

这篇论文探讨了镜像学习空间中元学习的一种「漂移」(drift)函数,称其为「学习策略优化」(LPO)。

通过对 LPO 的分析,研究人员获得了对策略优化的独到见解,并用它来构造一个新颖的、封闭形式的 RL 算法,发现策略优化(DPO)。

最后在 Brax 环境中的实验证实了LPO和DPO的最先进性能,以及迁移到未见过环境的能力。

Model-Free Opponent Shaping

在非零和博弈(general-sum game)中,self-interested学习智能体的相互作用通常导致集体最坏情况的结果,如重复囚徒困境(IPD)。

为了克服这一问题,一些方法,如对手学习意识学习(LOLA) 形成了对手的学习过程,不过这些方法通常是短视的,因为只有少数步骤可以预测,并且是不对称的,主要是因为它们将其他智能体视为朴素的学习者,并且需要通过白盒访问对手的可微学习算法来计算高阶导数。

                                               论文链接:https://arxiv.org/pdf/2205.01447.pdf

为了解决这些问题,研究人员提出了无模型对手形成算法(M-FOS)。

M-FOS 在一个元游戏中学习,其中每一个元步骤都是潜在的内在游戏的一个回合;元状态由内部策略组成,元策略产生一个新的内部策略,将在下一episode中回合;然后,M-FOS 使用通用的无模型优化方法来学习元策略,以形成长期的对手。

根据经验,M-FOS 近乎最优地利用了朴素的学习者和其他更复杂的文献算法,它是学习知识产权保护中著名的零行列式(ZD)勒索策略的第一种方法。

在相同的设置下,M-FOS 可以获得meta-self-play游戏下的社会最优结果,并可以扩展到高维设置。

Adversarial Cheap Talk

强化学习中的对抗性攻击(RL)通常假定受害者对参数、环境或数据的访问具有高权限。

相反,这篇论文提出了一种称为简单对话 MDP 的全新对抗设置,在这种设置中,对手只需在被攻击者的观察中附加确定性信息,就可以产生最小范围的影响。

                                             论文链接:https://arxiv.org/pdf/2211.11030.pdf

攻击者不能掩盖ground truth,影响潜在的环境动态或奖励信号,引入非平稳性,增加随机性,查看被攻击者的行为,或访问他们的参数。

此外,研究人员在这个设置下提出了一个简单的元学习算法,也称为简单谈话(ACT)训练对手。

实验证明,尽管存在高度约束的设置,一个对手训练的 ACT 仍然可以显着影响受害者的训练和测试表现;对训练时间性能的影响也提供了一个新的攻击向量,并可以观察理解现有强化学习算法的成功和失败模式。

更具体地说,研究人员展示了 ACT 对手能够通过干扰学习者的函数逼近来损害表现,或者通过输出有用的功能来帮助被攻击者提升表现。

最后展示了一个 ACT对手可以在训练时间操纵消息,在测试阶段任意控制被攻击者。

参考资料:

https://chrislu.page/blog/meta-disco/

免责声明:作者保留权利,不代表本站立场。如想了解更多和作者有关的信息可以查看页面右侧作者信息卡片。
反馈
to-top--btn