摘要

增强大型语言模型(LLMs)推理能力的一种方式是使用思想连(CoT)注释执行有监督微调(SFT)。然而,该方式不能展示足够强的泛化能力,因为训练仅依赖于给定的CoT数据。在数学问题求解中,例如,在训练数据中每个问题仅有一个注释的推理路径。直观上,算法最好从来自给定问题的多个注释推理路径上学习。为了解决这个问题,我们提出一个简单但有效的方式,称为强化微调(ReFT)来增强学习的LLMs进行推理的泛化能力,使用数学问题求解作为示例。ReFT首先使用SFT预热模型然后使用在线强化学习,本文中指定为PPO算法,来进一步微调模型,其中,给定问题,会自动采样大量推理路径,并且奖励自然地来自真实答案。在GSM8K、MathQA和SVAMP数据集上进行的广泛实验表明,ReFT的表现明显优于SFT,通过结合多数投票和重新排名等推理时间策略,可以进一步提高性能。请注意,ReFT通过从与SFT相同的训练问题中学习来获得改进,而不依赖于额外的或增强的训练问题。这表明ReFT具有更强的泛化能力。论文地址 代码

引言

解决数学问题的最先进方法(Luo等人,2023;Wang等人,2023a)采用监督微调(SFT)来使用思维链(CoT)注释训练模型(Wei等人,2022)。如图1所示,CoT注释概述了解决数学问题的中间推理步骤。

通常训练数据中,每个问题由一个CoT注释,即,一条正确的推理路径,在SFT中利用。我们观察到,这可能导致SFT模型相对弱的泛化能力。它常常是同样的问题存在多个有效CoT注释的情况,强调了需要更加强大的微调方法。为了解决这个问题,我们提出一个简单但有效的方式,称为强化微调(ReFT)(图1底部)。

ReFT开始于涉及一两个回合的监督微调(SFT)的预热阶段(图1,阴影框)。该初始阶段使模型具备一定程度地生成数学问题正确回复的能力,如先前工作所证明(Cobbe等,2021a)。接下来,ReFT 通过利用在线强化学习 (RL) 算法 (Sutton and Barto, 2018),特别是本文中的近端策略优化 (PPO) (Schulman et al., 2017) 进一步完善模型。通过这种方式,ReFT能够对多个正确的推理路径或CoT注释进行采样,并从中学习(图2,右)。

由于训练数据包含真实答案,因此在训练 PPO 时可以自然地从中得出黄金奖励。因此,不需要单独训练奖励模型。相比之下,RLHF(Ouyang 等人,2022 年)必须使用从人工标记数据中学习到的奖励模型。

在预热阶段,ReFT 通过监督学习获得一定的准确率。在强化学习阶段,ReFT 通过对各种 CoT 推理路径进行采样,通过强化学习进一步增强其能力。这样,ReFT 获得的监督信号比 SFT 丰富得多。这种方法使 ReFT 能够大大提高数学问题解决的泛化能力(Gao et al.,2018;Brown et al.,2020)。请注意,ReFT 通过使用相同的训练问题,而不依赖额外或增强的训练问题,其表现优于 SFT。事实上,ReFT 与这种数据工程并不冲突,可以与之无缝结合。

我们的贡献如下:

  • 我们介绍了一种新的微调方法,即强化微调(ReFT),它利用强化学习来解决数学问题。当在相同的数据集上训练时,与传统的监督微调相比,ReFT表现出更强的泛化能力。
  • 我们使用两个基础模型CodeLLAMA(Roziere等人,2023)和Galactica(Taylor等人,2022)在三个标准数据集上进行了广泛的实验:GSM8K(Cobbe等人,2021a)、MathQA(Amini等人,2019)和SVAMP(Patel等人,2021)。我们的实验涵盖了自然语言和基于程序的 CoT,证明了 ReFT 的性能和泛化能力的显著提升。
  • 此外,我们证明了ReFT在推理时受益于多数投票(Wang等人,2023b)和奖励模型重新排序(Uesato等人,2022),从而进一步提高了其性能。

相关工作

数学问题求解 最近的研究工作集中在CoT快速设计和数据工程上。他们中的大多数人试图使CoT变得全面和细粒度,以逐步呈现推理解决方案(Nye等人,2021;Fu等人,2023;Zhou等人,2023b;Khot等人,2023年;Zelikman等人,2022;Imani等人,2023)。Gao等人(2023)进一步提出使用Python程序作为CoT提示,展示了比自然语言CoT更准确的推理步骤和显著改进(Wei等人,2022)。Zhou等人(2023a)介绍了一种提示方法,该方法生成代码以验证GPT-4(OpenAI,2023)的中间推理步骤,从而在GSM8K(Cobbe等人,2021a)和MATH(Hendrycks等人,2021)上实现了最先进的性能。另一项工作侧重于提高CoT的质量(Wang等人,2023a;Liu等人,2023;Yu等人,2023),并增加OpenAI的ChatGPT(gpt-3.5-turbo)或gpt-4的CoT数据量(Luo等人,2023,Yue等人,2023年)。

强化学习 我们的工作主要与最近将PPO(Schulman等人,2017)应用于自然语言过程以调整人类偏好的工作有关(欧阳等人,2022)。从那时起,已经提出了几种训练算法来有效地改善对齐,包括直接偏好优化(DPO)(Rafailov等人,2023)、身份偏好优化(IPO)(Azar等人,2023年)和Kahneman Tversky优化(KTO)(Ethayarajh等人,2023。除了对齐的目的外,我们的目标是采用强化学习作为微调范式,以提高传统监督微调的性能。

具体来说,为了解决数学问题,Uesato 等人 (2022) 和 Lightman 等人 (2023) 训练了一个基于结果或基于过程的奖励模型来执行重新排序 (Cobbe 等人,2021a),从而实现了比 SFT 和多数投票 (Wang 等人,2023b) 更好的性能。虽然我们的方法旨在提高策略本身的性能,但这些奖励模型重新排序方法可以轻松集成到最终的策略模型中。

方法

在这项工作中,我们专注于使用Python的自然语言CoT(N-CoT)(Wei等人,2022)(图1)和基于程序的CoT(Gao等人,2023)(P-IoT)。Gao等人(2023)提出了基于程序的CoT来解决数学问题。我们可以简单地执行程序来获得答案。为了确保清晰性并避免歧义,我们分别使用术语N-CoT和P-CoT来表示自然语言和基于程序的CoT。

强化微调

所提出的强化微调(ReFT)过程包括两个阶段:预热阶段和强化学习阶段。整体算法如算法1所示。

预热 在这个阶段,该策略在由“(问题,CoT)”元组(x,e)组成的数据集上进行了几个回合的微调。它使模型具有基本的解决问题的能力,以生成适当的响应。从形式上讲,CoT生成过程可以分解为一系列下一个标记预测动作。最后一个动作标记 <eos><eos> 表示生成过程终止。CoT e写作如下:

e=[a1,a2,,aL1,aL=<eos>]e = [a_1,a_2,\dots,a_{L-1},a_L=<eos>]

其中L表示最大长度。在时间步t,动作 ata_t 采样自策略 πθ(st)\pi_{\theta}(\cdot|s_t) ,其中 ata_t 可以是词汇表中的任何标记,状态 sts_t 包括问题中的所有标记和迄今为止生成的所有标记。在每个动作之后,得到的状态 st+1s_{t+1} 是当前状态 sts_t 和动作 ata_t 在以下位置的拼接:

st+1={x,t=0[st,at],1tLs_{t+1} = \begin{cases} x, \quad &t = 0 \\ [s_t,a_t], \quad & 1 \le t \le L \end{cases} \cdot

由于产生的动作是 <eos><eos> 标记,因此生成的状态 sL+1s_{L+1} 是终止状态,生成过程结束。使用这种符号,样本的损失函数可以写成:

LSFT(θ)=EeD[t=1Tlog(πθ(atst))](1)L_{SFT}(\theta) = -E_{e\sim D} \left[\sum_{t=1}^T \mathcal{log}(\pi_{\theta}(a_t|s_t))\right] \tag{1}

强化学习 在这个阶段,该策略通过使用由(问题,答案)元组:(x,y)组成的数据集进行在线自学习来提高其性能。具体来说,策略模型通过反复采样响应(图2)、评估响应的答案正确性并以在线方式更新其参数来学习(算法1中的第7-14行)。我们采用PPO(Schulman等人,2017)和截断目标算法进行训练。根据Ziegler等人(2019)的研究,价值模型 VϕV_{\phi} 是通过在预热阶段后的策略模型 πθ\pi_{\theta} 的最后隐藏状态上附加一个线性值头来构建的。对于所有导致非终止状态的操作,奖励为0。在终止状态,我们使用一个奖励函数,直接比较从该状态的CoT中提取的答案和地面真实答案y。这里,如果答案被认为是正确的,则奖励函数返回1,否则返回0。在答案都是数字的数据集上,当答案可以提取并且是数字类型时,可以应用0.1的部分奖励(Zhong等人,2017;Le等人,2022)。对于 1tL1\le t\le L ,我们写作

r(st,at,st+1)={1,EXTRACT(st+1)=y0.1,EXTRACT(st+1)null,y0,EXTRACT(st+1)=nullr(s_t,a_t,s_{t+1}) = \begin{cases} 1, \quad & EXTRACT(s_{t+1}) = y \\ 0.1, \quad &EXTRACT(s_{t+1}) \ne null, \ne y \\ 0, \quad &EXTRACT(s_{t+1}) = null \end{cases}

这种部分奖励可以帮助减少从稀疏奖励中学习的影响(Riedmiller等人,2018;Trott等人,2019)。此外,根据Zheng等人(2023)的研究,我们的总奖励是奖励函数得分, 与学习的RL策略和初始策略之间的Kullback-Leibler(KL)散度之和(Kullback和Leibler,1951),该散度由系数因子 β\beta 缩放。

rtotal(st,at,st+1)=r(st,at,st+1)βKL(πθ(st),πθ(0)(st))r_{total}(s_t,a_t,s_{t+1}) = r(s_t,a_t,s_{t+1}) - \beta\text{KL}(\pi_{\theta}(\cdot|s_t),\pi_{\theta}^{(0)}(\cdot|s_t))

广义优势估计(Schulman等人,2018)用于优势计算:

A^t=l=0Lt(γλ)lδt+l,\hat{A}_t = \sum_{l=0}^{L-t}(\gamma\lambda)^{\mathcal{l}}\delta_{t+\mathcal{l}},

其中,时序差分(TD)定义为

δt=Vϕ(st)+rtotal(st,at,st+1)+γVϕ(st+1)\delta_{t'} = -V_{\phi}(s_{t'}) + r_{total}(s_{t'},a_{t'},s_{t'+1})+ \gamma V_{\phi}(s_{t'+1})

终止状态值 Vϕ(sL+1):=0V_{\phi}(s_{L+1}) := 0λ(0,1]\lambda \in (0,1] 是奖励的折扣因子,γ[0,1]\gamma \in [0,1] 是TD的折扣因子。为了估计回报,我们利用 λreturnR^t\lambda -return \hat{R}_t,可以写为广义优势估计和值估计之和:

R^t=A^t+Vϕ(st)\hat{R}_t = \hat{A}_t + V_{\phi}(s_t)

最后,策略和值目标可以写为如下两个方程:

Lpolicy(θ)=Eeπθold[min(πθ(atst)πθold(atst)A^t,  clip(πθ(atst)πθold(atst),1ϵ,1+ϵ)A^t)]L_{policy}(\theta) = -E_{e\sim \pi_{\theta_{old}}}\left[min\big(\frac{\pi_{\theta}(a_t|s_t)}{\pi_{\theta_{old}}(a_t|s_t)}\hat{A}_t, \;clip(\frac{\pi_{\theta}(a_t|s_t)}{\pi_{\theta_{old}}(a_t|s_t)}, 1-\epsilon, 1+\epsilon)\hat{A}_t\big)\right]

Lvalue(ϕ)=12Eeπθold[max(Vϕ(st)R^t2,  clip(R^tVϕ(st),A^tϵ,A^t+ϵ)2)]L_{value}(\phi) = \frac{1}{2} E_{e\sim \pi_{\theta_{old}}}[max\big( ||V_{\phi}(s_t) - \hat{R}_t||^2, \; ||clip(\hat{R}_t - V_{\phi}(s_t),\hat{A}_t - \epsilon, \hat{A}_t + \epsilon)||^2\big)]

其中, πθold,Vϕold\pi_{\theta{old}}, V_{\phi_{old}} 用于采样CoT和计算 A^t,R^t\hat{A}_t,\hat{R}_t

统一损失函数是上述目标的加权和。

LRL(θ,ϕ)=Lpolicy+αLvalue(2)L_{RL}(\theta,\phi) = L_{policy} + \alpha L_{value} \tag{2}

其中 α\alpha 是值目标的系数。