摘要
arXiv:2406.12120v2 宣告类型: 替换交叉
摘要:扩散模型是一种强大的生成模型,允许对生成样本的特征进行精确控制。虽然这些在大型数据集上训练的扩散模型已经取得了成功,但在下游微调过程中常常需要引入额外的控制。这些强大的模型被视为预训练的扩散模型。本文提出了一种基于强化学习(RL)的新方法,使用包含输入和标签的离线数据集来添加这些控制。我们将此任务形式化为一个RL问题,学习自离线数据集的分类器和相对于预训练模型的KL散度作为奖励函数。我们的方法$\textbf{CTRL}$(使用$\textbf{R}$奖$\textbf{L}$学习条件化预训练扩散模型)生成最大化上述奖励函数的软最优策略。我们正式证明,我们的方法在推断过程中使得能够在有额外控制的条件下进行采样。基于RL的方法相对于现有方法具有多项优势。与无分类器引导相比,它提高了采样效率,并且通过利用输入与附加控制之间的条件独立性,极大地简化了数据集的构建。此外,与分类器引导不同,它消除了需要从中间状态训练分类器以获得额外控制的需要。代码可在https://github.com/zhaoyl18/CTRL获取。