摘要
arXiv:2410.08893v3 宣布类型:替代交叉
摘要:基于模型的强化学习(RL)提供了一种解决大多数无模型RL算法的数据效率低问题的方法。然而,学习一个稳健的世界模型通常需要复杂的深层结构,这在计算上成本高昂且难以训练。在世界模型中,序列模型在准确预测中起着关键作用,各种架构已被探索,每个架构都有其自身的挑战。目前,基于递归神经网络(RNN)的世界模型难以处理梯度消失问题和捕捉长期依赖性。相比之下,变压器(Transformers)由于自注意力机制的二次内存和计算复杂度,放大为 $O(n^2)$,其中 $n$ 是序列长度,存在挑战。
为了应对这些挑战,我们提出了一种基于状态空间模型(SSM)的世界模型 Drama,特别利用了 Mamba,该模型实现了 $O(n)$ 的内存和计算复杂性,同时有效地捕捉长期依赖性,并允许使用较长序列进行高效的训练。我们还介绍了一种新的采样方法,以减轻早期训练阶段错误世界模型导致的次优性。结合这些技术,Drama 在 Atari100k 基准测试中实现了与当前最先进的(SOTA)基于模型的 RL 算法相竞争的标准化得分,仅使用一个包含 700 万个参数的世界模型。Drama 可在标准台式机等现成硬件上访问和训练。我们的代码可在 https://github.com/realwenlongwang/Drama.git 获取。