摘要

虽然扩散模型在图像、音频和视频生成领域获得显著进展,但它们依赖于迭代采样过程,导致生成缓慢。为克服该限制,提出一致性模型。一致性模型直接映射噪声到数据,以生成高质量样本。通过设计,支持快速一步生成,同时允许多步采样,以平衡计算和采样质量。它们还支持零样本(zero-shot)数据编辑,例如图像恢复、着色和超分辨,而无需再这类任务上额外训练。一致性模型可以通过两种方式训练:(1)蒸馏预训练扩散模型,或者(2)作为独立生成模型。通过大量实验证明,它们优于现有扩散模型一步和几步采样的蒸馏技术,对于一步生成,获得CIFAR-10上3.55和ImageNet 64×6464\times64 上6.20的先进FID水平。当独立训练时,一致性模型成为一个新系列的生成模型,在诸如CIFAR-10、ImageNet64x64和LSUN 256x256标准基准上,优于现有的一步、非对抗生成模型。

如下图,数据通过概率流常微分方程(PF-ODE)转为噪声。在PF-ODE轨迹上的任意点 xtx_t 通过学习均可映射到原点x0x_0。这些映射模型被称为一致性模型,它们的输出,经过训练,同一轨迹的点的是一致的。

将ODE轨迹上的任一点映射到原始点

引言

扩散模型,也称基于分数的生成模型,在多个领域获得前所未有的成功,包括图像生成、音频合成和视频生成。扩散模型的一个关键特征是迭代采样过程,渐进地移除随机初始向量的噪声。该迭代过程提供灵活的计算和采样质量平衡,通过使用额外的迭代计算通常生成更好的质量。它也是扩散模型的许多零样本数据编辑功能的关键,使它们能够解决从图像修复、着色、笔划引导图像编辑到计算机断层扫描和磁共振成像等具有挑战性的逆向问题。然而,对于单步生成模型,如GANs、VAEs或正则化流(normalizing flows),扩散模型的迭代生成过程通常需要10-2000次的样本生成计算,导致缓慢推理,并限制实时应用。

本文的目标是创建高效、单步生成模型,并且不牺牲迭代采样的重要优势,例如,当必要时权衡计算和采样质量,也能够执行零样本数据编辑任务。如图1所示,我们构建在连续时间扩散模型中的概率流常微分方程上,它的轨迹平滑地将数据分布转变为易于处理的噪声分布。我们提出学习一个模型,映射任意时间步上的任意点到轨迹的起点。我们模型一个值得注意的属性是自我一致性:同一轨迹上的点映射到同一初始点。因此,模型称为一致性模型。一致性模型允许我们通过一次网络评估(one network evaluation)转换随机噪声向量(ODE轨迹的终点,如图1中的 xTx_T)来生成数据样本(ODE轨迹的初始点,图1中的 x0x_0 )。重要的是,通过改变一致性模型在多个时间步的输出,以更多的计算代价,可以改进样本质量,并且执行零样本数据编辑,类似于扩散模型中的迭代采样。

基于强制自我一致性属性,训练一致性模型提供了两种方法。(1)依赖于使用数值常微分求解器(numerical ODE solvers)和一个预训练扩散模型来生成PF ODE轨迹上的相邻点对。通过最小化这些点对的模型输出差异,可以有效蒸馏扩散模型为一致性模型,允许一次网络评估生成高质量的样本。相反,(2)无需预训练扩散模型,允许独立训练一致性模型。该方法将一致性模型视为独立生成模型系列。重要的是,两种方法都无需对抗训练,都对架构施加了最小限制,允许灵活使用神经网络来参数化一致性模型。

我们在若干图像数据集上证明了一致性网络的有效性,包括CIFAR-100、ImageNet 64x64和LSUN 256x256。经验上,我们观测到作为蒸馏方法,一致性模型在多种数据集上的一步生成优于现有扩散蒸馏方法,如渐进蒸馏。在CIFAR-10中,一致性模型达到新水平,一步生成和两步生成的FIDs分别为3.55和2.93;在ImageNet 64x64中,它取得破纪录的一步和两步网络评估的FIDs分别为6.20和4.70。当作为独立生成模型训练时,一致性模型可以匹配或超过来自渐进蒸馏的样本质量,尽管无需预训练扩散模型。它们还在多个数据集上优于许多GANs和现有的非对抗性、单步生成模型。此外,我们证明一致性模型可以用于执行广泛的零样本数据编辑任务,包括图像去噪、插值、恢复、着色、超分辨率和笔画引导图像编辑(SDEdit, stroke-guided image editing)。

扩散模型

一致性模型受连续时间扩散模型的高度启发。扩散模型通过高斯扰动渐进扰动数据到噪声,然后通过序列去噪步骤从噪声中创建样本。令 pdata(x)p_{data}(x) 表示数据分布。扩散模型起始于用随机微分方程(SDE)扩散 pdata(x)p_{data}(x)

dxt=μ(xt,t)dt+σ(t)dwt,(1)dx_t = \mu(x_t,t)dt + \sigma(t)dw_t\,, \tag{1}

其中, t[0,T]t\in[0,T], T > 0是固定常量, μ(,)\mu(\cdot,\cdot)σ()\sigma(\cdot) 分别为漂移和扩散系数, {wt}t[0,T]\lbrace w_t\rbrace_{t\in[0,T]}表示标准布朗运动。xtx_t 的分布表示为 pt(x)p_t(x), 因此 p0(x)pdata(x)p_0(x) \equiv p_{data}(x)。该SDE的一个显著性质是存在一个常微分方程(ODE),即概率流(PF)ODE,其在 t 处采样的解轨迹分布根据 pt(x)p_t(x)

dxt=[μ(xt,t)12σ(t)2logpt(xt)]dt.(2)dx_t = [\mu(x_t,t) - \frac{1}{2}\sigma(t)^2\triangledown\mathcal{log}p_t(x_t)]dt\,. \tag{2}

这里, logpt(x)\triangledown\mathcal{log}p_t(x)pt(x)p_t(x)的分数函数,因此扩散模型也称为基于分数的生成模型(score-based generative models)。

通常,方程(1)中的SDE被设计使得 pTxp_{T}x 为接近容易处理的高斯分布 π(x)\pi(x)。我们采用EDM中的设置,其中 μ(x,t)=0\mu(x,t) = 0 , σ(t)=2t\sigma(t) = \sqrt{2t} 。此时,我们得到 pt(x)=pdata(x)N(0,t2I)p_t(x) = p_{data}(x) \odot \mathcal{N}(0,t^2\mathrm{I}) ,其中 \odot 表示卷积操作,且 π(x)=N(0,T2I)\pi(x) = \mathcal{N}(0,T^2\mathrm{I}) 。对于采样,首先通过分数匹配(score matching)训练分数模型 sϕ(x,t)logpt(x)s_{\phi}(x,t) \approx \triangledown\mathcal{log}p_t(x),然后带入到方程(2)中获得PF ODE的经验估计,其采用如下形式:

dxtdt=tsϕ(xt,t).(3)\frac{dx_t}{dt} = -ts_{\phi}(x_t,t)\,. \tag{3}

方程(3)称为经验PF ODE。接着,采样 x^Tπ=N(0,T2I)\hat{x}_T \sim \pi = \mathcal{N}(0,T^2I) 来初始化经验PF ODE,并用任意数值求解器,如Euler和Heun求解器,进行逆时间求解,以获得解轨迹 {x^t}t[0,T]\lbrace\hat{x}_t\rbrace_{t\in[0,T]} 。结果 x^0\hat{x}_0 可以视为来自数据分布 pdata(x)p_{data}(x) 的近似样本。为避免数值不稳定,通常在 t=ϵt = \epsilon 停止求解器,其中, ϵ\epsilon是一个固定的小正数,并接受 x^ϵ\hat{x}_{\epsilon} 作为近似样本。遵循EDM,我们将图像像素值调整到[-1,1],并设置T = 80, ϵ=0.002\epsilon = 0.002

扩散模型的瓶颈在于缓慢的采样速度。显然,使用ODE求解器进行采样需要分数模型 sϕ(x,t)s_{\phi}(x,t) 的迭代估计,其计算成本高。现有快速采样方法包括快速数值求解器和蒸馏技术。然而,ODE求解器仍需超过10估计步骤来生成完整样本。大多数蒸馏方法,如Luhmanm,2021Zheng等,2022,在蒸馏前依赖于收集来自扩散模型的大规模的样本数据集本身计算昂贵。据我们所知,唯一没有该缺点的蒸馏方法是渐进蒸馏(PD),我们在实验中将其与一致性模型大量对比。

一致性模型

一致性模型是一类核心设计支持单步生成的模型,同时仍允许迭代生成,以权衡采样质量和计算,和零样本数据编辑。一致性模型可以用蒸馏模式或独立模式进行训练。前者,一致性模型蒸馏预训练扩散模型的知识到单步采样器,显著改进其他蒸馏方法的样本质量,同时允许零样本图像编辑应用。后者,一致性模型独立训练,不依赖预训练扩散模型。这使它们成为一类独立的新生成模型。

以下介绍一致性模型的定义、参数化和采样,以及简单讨论其在零样本数据编辑上的应用。
定义 给定方程(2)中PF ODE的解轨迹 xtt[ϵ,T]{x_t}_{t\in[\epsilon,T]},我们定义一致性函数为 f:(xt,t)xϵf : (x_t,t) \rightarrow x_{\epsilon} 。 一致性函数具有自我一致性的性质:对于属于同一PF ODE轨迹的任意对(x_t,t),其输出是一致的,即对于所有t, f(xt,t)=f(xt,t)f(x_t,t) = f(x_{t^{'}},t^{'}), t[ϵ,T]t \in [\epsilon,T]。如图2所示,一致性模型的目标,符号表示为 fθ\mathcal{f}_{\theta},是通过学习强制一致性属性(详见章节4和5)从数据中估计该一致性函数 f\mathcal{f} 。类似的定义用于神经ODEs中神经流(neural flows)。然而,相较于神经流,我们不强制一致性模型是可逆的。

一致性模型经过训练,可以将PF ODE的任何轨迹上的点映射到轨迹的原点。
一致性模型经过训练,可以将PF ODE的任何轨迹上的点映射到轨迹的原点

参数化 对于任意一致性模型 f(,)\mathcal{f}(\cdot,\cdot),我们有 f(xϵ,ϵ)=xϵ\mathcal{f}(x_{\epsilon,\epsilon}) = x_{\epsilon},即 f(,ϵ)f(\cdot,\epsilon) 是恒等函数。该限制称为边界条件。所有一致性模型必须满足该边界条件,它在一致性模型的成功训练中至关重要。该边界条件也是一致性模型中最受限制的架构约束。对于基于深度神经网络的一致性模型,讨论两种几乎免费实现该边界条件的方式。假设有一个自由形式的深度圣经网络 Fθ(x,t)F_{\theta}(x,t) ,其输出与x具有同样的维度。第一种方式是简单参数化一致性模型为:

fθ(x,t)={xt=ϵFθ(x,t)t(ϵ,T].(4)f_{\theta}(x,t) = \begin{cases} x \quad\quad\quad\quad\quad t = \epsilon \\ \mathrm{F}_{\theta}(x,t) \quad\quad t \in (\epsilon,T] \end{cases} \,. \tag{4}

第二种方法是使用跳跃连接(skip connections)参数化一致性模型,即

fθ(x,t)=cskip(t)x+cout(t)Fθ(x,t),(5)f_{\theta}(x,t) = c_{skip}(t)x + c_{out}(t)F_{\theta}(x,t)\,, \tag{5}

其中, cskip(x)c_{skip}(x)cout(t)c_{out}(t) 是可微函数,使得 cskip(ϵ)=1c_{skip}(\epsilon) = 1 ,和 cout(ϵ)=0c_{out}(\epsilon) = 0 。如此,如果 Fθ(x,t)F_{\theta}(x,t)cskip(t)c_{skip}(t)cout(t)c_{out}(t) 全部可微,一致性模型在 t=ϵt = \epsilon 是可微的,这对于训练一致性模型至关重要(附录B.1和附录B.2)。方程(5)中的参数化与许多成功的扩散模型非常相似,使得它较为容易借助强大的扩散模型架构来构建一致性模型。因此,我们在所有实验中遵循第二种参数化。

采样 使用训练良好的一致性模型 fθ(,)\mathcal{f}_{\theta}(\cdot,\cdot) ,我们可以通过从初始分布 x^TN(0,T2I)\hat{x}_T \sim \mathcal{N}(0,T^2I) 中采样生成样本,然后评估 x^ϵ=fθ(x^T,T)\hat{x}_{\epsilon} = \mathcal{f}_{\theta}(\hat{x}_T,T) 的一致性模型。这涉及仅一次一致性模型前向传播,因此一步生成样本。重要的是,可以通过交替地去噪和噪声注入步骤多次估计一致性模型来改进样本质量。在算法1中总结,该多部采样过程提供权衡计算和样本质量的灵活性。在实际操作中,我们使用贪婪算法在算法 1 中寻找时间点 {T1,T2,,TN1}\lbrace\mathcal{T}_1, \mathcal{T}_2, \cdot\cdot\cdot,\mathcal{T}_{N-1}\rbrace ,即使用三元搜索法逐个精确定位时间点,以优化算法 1 所获样本的 FID。这假设给定之前的时间点,FID是下一个时间点的单峰函数。我们发现这一假设在我们的实验中是成立的,并将探索更好的策略作为未来的工作。

输入:一致性模型,时间点序列,初始噪声
过程:初始噪声和时间步传入一致性模型得到初步生成x, 通过采样高斯向量计算每步噪声量更新前一步生成的x,输入一致性模型获得更新后的再次生成。如此循环更新N-1次,获得最终的输出x。

多步一致性采样

零样本数据编辑 类似于扩散模型,一致性模型能够进行各种零样本数据编辑和操作应用,无需显示训练执行这些任务。例如,一致性模型定义来自高斯噪声向量到数据样本的一对一映射。类似于潜变量模型,如GANs、VAEs和归一化流,一致性模型可以通过便利潜在空间轻容在样本间插值(图11)。由于一致性模型训练用于从任意噪声输入 xtx_t 中恢复 xϵx_{\epsilon} ,其中 t[ϵ,T]t \in[\epsilon,T] ,它们可以执行各种噪声水平的去噪(图12)。此外,算法1中的多步生成过程,通过类似于扩散模型的迭代替换过程,对于解决特定零样本逆向问题很有用。这使得图像编辑中的许多应用成为可能,包括恢复、着色、超分辨率和笔画引导的图像编辑。在第6.3节中,我们通过经验证明了一致性模型在许多零样本图像编辑任务中的威力。

通过蒸馏训练一致性模型

基于蒸馏预训练分数模型 sϕ(x,t)s_{\phi}(x,t) 训练一致性模型。讨论围绕方程(3)中的经验PF ODE展开,该方程是通过将分数模型 sϕ(x,t)s_{\phi}(x,t) 代入PF ODE中获得的。考虑将时间范围 [ϵ,T][\epsilon,T] 离散化为N-1个子区间,边界为 t1=ϵ<t2<<tN=Tt_1 = \epsilon \lt t_2 \lt \cdot\cdot\cdot \lt t_N = T。实际中,我们遵循EDM用公式 ti=(ϵ1/ρ+i1/N1(T1/ρϵ1/ρ))ρt_i = (\epsilon^{1/\rho} + i-1/N-1(T^{1/\rho - \epsilon^{1/\rho}}))^{\rho} 来确定边界,其中 ρ\rho = 7。当N足够大,我们可以通过运行一次离散步数值ODE求解器获得来自 xtn+1x_{t_{n+1}}xtnx_{t_n} 的准确估计。该估计,表示为 x^tnϕ\hat{x}_{t_n}^{\phi} ,定义为:

x^tnϕ:=xtn+1+(tntn+1)ϕ(xtn+1,tn+1;ϕ),(6)\hat{x}_{t_n}^{\phi} := x_{t_{n+1}} + (t_n - t_{n+1})\phi(x_{t_{n+1}},t_{n+1};\phi)\,, \tag{6}

其中, ϕ(;ϕ)\phi(\cdot\cdot\cdot;\phi) 表示应用于经验PF ODE的一步ODE求解器的更新函数。例如,当使用Eluer求解器,有 ϕ(x,t;ϕ)=tsϕ(x,t)\phi(x,t;\phi) = -ts_{\phi}(x,t) ,其对应于如下更新规则:

x^tnϕ=xtn+1(tntn+1)tn+1sϕ(xtn+1,tn+1).\hat{x}_{t_n}^{\phi} = x_{t_{n+1}} - (t_n - t_{n+1})t_{n+1} s_{\phi}(x_{t_{n+1}},t_{n+1}) \,.

为了简化,本文中仅考虑一步ODE求解器。将我们的框架推广到多步ODE求解器是很简单的,我们将其作为未来的工作。

由于方程(2)中的PF ODE和方程(1)中的SDE之间的联系(见第2节),可以通过首先采样 xpdatax\sim p_{data} ,然后将高斯噪声添加到x来沿着ODE轨迹的分布进行采样。具体地,给定数据点x,可以通过从数据集采样x在PF ODE轨迹上高效生成一对相邻数据点 (x^tnϕ,xtn+1)(\hat{x}_{t_n}^{\phi},x_{t_{n+1}}) ,接着从 SDE N(x,tn+12I)\mathcal{N}(x, t_{n+1}^2I) 的转换密度采样 xtn+1x_{t_{n+1}},然后根据方程(6)使用一次离散步数值ODE求解器计算 x^tnϕ\hat{x}_{t_n}^{\phi}。然而,通过最小化它在对 (x^tnϕ,xtn+1)(\hat{x}_{t_n}^{\phi},x_{t_{n+1}}) 上的输出差异训练一致性模型。因此,我们提出了以下用于训练一致性模型的一致性蒸馏损失。

定义 1 一致性蒸馏损失定义为:

LCDN(θ,θ;ϕ):=E[λ(tn)d(fθ(xtn+1,tn+1),fθ(x^tnϕ,tn))],(7)\mathcal{L}_{CD}^{N}(\theta,\theta^-;\phi) := E[\lambda(t_n)d(\mathcal{f}_{\theta}(x_{t_{n+1}},t_{n+1}), \mathcal{f}_{\theta^-}(\hat{x}_{t_n}^{\phi},t_n))]\,, \tag{7}

其中,期望值是相对于 xpdatax \sim p_{data}nU[1,N1]n \sim \mathcal{U}[1,N-1]xtn+1N(x;tn+12I)x_{t_{n+1}} \sim \mathcal{N}(x;t_{n+1}^2 I) 取的。这里, U[1,N1]\mathcal{U}[1,N-1]表示 {1,2,,N1}\lbrace 1,2,\cdot\cdot\cdot, N - 1\rbrace 上的均匀分布, λ()R+\lambda(\cdot) \in R^{+} 是正加权函数, x^tnϕ\hat{x}_{t_n}^{\phi} 由方程(6)给出, θ\theta^- 表示优化过程中 θ\theta 历史值的连续平均值, 且 d(,)d(\cdot,\cdot) 是度量函数,其满足 x,y:d(x,y)0\forall_{x,y} : d(x,y) \ge 0 ,当且仅当 x = y时, d(x,y)=0d(x,y) = 0

除非另外声明,我们采用定义1中的符号贯穿全文,并且使用 E[]E[\cdot] 表示所有随机变量之上的期望。实验中,考虑平方 l2\mathcal{l}_2 距离 d(x,y)=xy22d(x,y) = ||x - y||_2^2 、 \mathcal{l}_ 距离 d(x,y)=xy1d(x,y) = ||x - y||_1 和 学习感知图像补丁相似性(LPIPS, Learned Perceptual Image Patch Similarity)。我们发现, λ(tn)1\lambda(t_n) \equiv 1 在所有任务和数据集上执行良好。实际中,我们通过在随机梯度下降模型参数 θ\theta最小化目标,同时用指数移动平均(EMA)更新 θ\theta^- 。即,给定衰减率 0μ<10 \le \mu \lt 1 ,在每个优化步,我们执行如下更新:

θstopgrad(μθ+(1μ)θ).(8)\theta^- \leftarrow \mathcal{stopgrad}(\mu\theta^- + (1 - \mu)\theta)\,. \tag{8}

上述训练过程总结在算法2。为符合深度强化学习(DQNDDPG)和基于动量的对比学习(BYOLMoCo),我们称 fθ\mathcal{f}_{\theta^-} 为“目标网络”, fθ\mathcal{f}_{\theta} 为“在线网络”。我们发现,与简单设置 θ=θ\theta = \theta^- 相比,方程(8)中的EMA更新和“stopgrad”算子可以大大稳定训练过程,提高一致性模型的最终性能。

输入:数据集D,初始模型参数 θ\theta ,学习率 η\eta , ODE求解器 \Phi(\cdot,\codt;\phi) ,d(,)d(\cdot,\cdot)λ()\lambda(\cdot)μ\mu
一致性蒸馏

下面,我们基于渐近分析为一致性蒸馏提供了理论依据。

定理 1t:=maxn[1,N1]{tn+1tn}\triangle t := \mathcal{max}_{n\in[1,N-1]}\lbrace|t_{n+1} - t_n|\rbrace , 且 f(,;ϕ)\mathcal{f}(\cdot,\cdot;\phi) 为方程(3)中经验PF ODE的一致性函数。假设 fθ\mathcal{f}_{\theta} 满足李普希茨条件(Lipschitz condition):存在 L<0L \lt 0 使得对于全部 t[ϵ,T]t \in [\epsilon,T] , x 和y,有 fθ(x,t)fθ(y,t)2Lxy2||\mathcal{f}_{\theta}(x,t) - \mathcal{f}_{\theta}(y,t)||_2 \le L||x - y||_2 。进一步假设,对于所有 n[1,N1]n\in[1,N-1] ,ODE求解器在 tn+1t_{n+1}处调用时具有被 O((tn+1tn)p+1)O((t_{n+1} - t_n)^{p+1}) 均匀界定的局部误差,其中 p1p \ge 1。那么,若 LCDN(θ,theta;ϕ)=0\mathcal{L}_{CD}^{N}(\theta,theta;\phi) = 0 ,有

supn,xfθ(x,tn)f(x,tn;ϕ)2=O((t)p).\underset{n,x}{\mathcal{sup}}||\mathcal{f}_{\theta}(x,t_n) - \mathcal{f}(x,t_n;\phi)||_2 = O((\triangle t)^p)\,.

证明 该证明基于归纳法,与数值常微分方程求解器全局误差界的经典证明相似。我们在附录A.2中提供了完整的证明。

由于 θ\theta^-θ\theta 历史值的连续平均值,当算法2的优化收敛时,有 θ=θ\theta^- = \theta 。即,目标和在线一致性模型将最终相互匹配。如果一致性模型额外实现了零一致性蒸馏损失,那么定理1意味着,在某些规律性条件下,只要ODE求解器的步长足够小,估计的一致性模型就可以变得任意精确。重要的是,边界条件 fθ(x,ϵ)0\mathcal{f}_{\theta}(x,\epsilon)\equiv 0 排除了一致性模型训练中出现平凡解 fθ(x,t)0\mathcal{f}_{\theta}(x,t) \equiv 0

一致性蒸馏损失 LCDN(θ,θ;ϕ)\mathcal{L}_{CD}^{N}(\theta,\theta^-;\phi) 可以扩展为无限多个时间步 (N)(N \rightarrow \infty) ,若 θ=θ\theta^- = \thetaθ=stopgrad(θ)\theta^- = \mathcal{stopgrad}(\theta) 。由此产生的连续时间损失函数不需要指定N或时间步长 {t1,t2,,tN}\lbrace t_1, t_2, \cdot\cdot\cdot, t_N \rbrace 。尽管如此,它们涉及雅可比向量积,需要前向模式自动微分才能有效实现,这在一些深度学习框架中可能得不到很好的支持。我们在定理3至5中提供了这些连续时间蒸馏损失函数,并将细节放在附录B.1中。

独立训练一致性模型

一致性模型可以无需依赖任何预训练扩散模型进行训练。这与现有扩散模型技术不同,使得一致性模型称为新的独立系列的生成模型。

回顾一致性蒸馏中,依赖预训练分数模型 sϕ(x,t)s_{\phi}(x,t) 来近似真实分数函数 logpt(x)\triangledown\mathcal{log}p_t(x) 。事实证明,通过利用以下无偏估计量(附录A中的引理1),我们可以完全避免这种预训练的分数模型:

logpt(xt)=E[xtxt2xt],\triangle \mathcal{log} p_t(x_t) = - E[\frac{x_t - x}{t^2}|x_t]\,,

其中, xpdatax \sim p_{data}xtN(x;t2I)x_t \sim \mathcal{N}(x;t^2 I) 。即,给定x和 xtx_t ,可以用 (xtx)/t2-(x_t - x)/t^2 来估计 logpt(xt)\triangle \mathcal{log} p_t(x_t)
NN \rightarrow \infty 的极限下, 当使用Euler方法作为ODE求解器时,该无偏估计足够替换一致性蒸馏中的预训练扩散模型,由以下结果证明。

定理 2t:=maxn[1,N1]{tn+1tn}\triangle t := max_{n \in [1,N-1]}\lbrace |t_{n+1} - t_n|\rbrace 。假设d和 fθ\mathcal{f}_{\theta^-} 均为有界二阶导数的二次连续可微,权重函数 λ(cdot)\lambda(cdot) 是有界的,且 E[logptn(xtn)22]<E[||\triangledown\mathcal{log}p_{t_n}(x_{t_n})||_2^2] \lt \infty 。进一步假设我们使用Euler求解器,且预训练分数模型匹配真实值,即 t[ϵ,T]:sϕ(x,t)logpt(x)\forall t\in[\epsilon,T] : s_{\phi}(x,t) \equiv \triangledown\mathcal{log}p_t(x) 。那么:

LCDN(θ,θ;ϕ)=LCDN(θ,θ)+o(t),(9)\mathcal{L}_{CD}^{N}(\theta,\theta^-;\phi) = \mathcal{L}_{CD}^{N}(\theta,\theta^-) + o(\triangle t)\,, \tag{9}

其中,期望值根据 xpdatax \sim p_{data}nU[1,N1]n \sim \mathcal{U}[1,N - 1] ,且 xtn+1N(x;tn+12I)x_{t_{n+1}} \sim \mathcal{N}(x;t_{n+1}^2 I) 取得。一致性训练目标,表示为 LCTN(θ,theta)\mathcal{L}_{CT}^{N}(\theta,theta^-) ,定义为:

E[λ(tn)d(f(x+tn+1z,tn+1),fθ(x+tnz,tn))],(10)E[\lambda(t_n)d(\mathcal{f}(x + t_{n+1}z,t_{n+1}), f_{\theta^-}(x + t_nz, t_n))]\,, \tag{10}

其中, zN(0,I)z \sim \mathcal{N}(0,I) 。此外,若 infNLCDN(θ,θ;ϕ)>0inf_N\mathcal{L}_{CD}^{N}(\theta,\theta^-;\phi) \gt 0LCTN(θ,θ)O(t)\mathcal{L}_{CT}^{N}(\theta,\theta^-) \ge O(\triangle t)

证明 该证明基于泰勒级数展开和分数函数的性质(引理1)。附录A.3提供了完整的证明。

我们将方程(10)成为i一致性训练(CT)损失。至关重要的是, L(θ,θ)\mathcal{L}(\theta,\theta^-) 仅取决于在线网络 \mathcal{f|_{\theta} 和目标网络 fθ\mathcal{f}_{\theta^-} ,而与扩散模型参数 ϕ\phi 完全无关。损失函数 L(θ,θ)O(t)\mathcal{L}(\theta,\theta^-) \ge O(\triangle t) 的下降速度比余数 o(t)o(\triangle t) 慢,因此将在方程(9)中主导损失,当 NN \rightarrow \infty ,且 t0\triangle t \rightarrow 0

为提高实际性能,我们提出在训练中根据调度函数 N()N(\cdot) 渐进增加N。直觉(参见图3d)是,当N很小(即 t\triangle t 很大)时,一致性训练损失相对于基础一致性蒸馏损失(即方程(9)的左侧)具有较小的“方差”,但具有较大的“偏差”,这有助于在训练开始时更快地收敛。相反,当N很大时(即 t\triangle t 很小),它有较大的“方差”,但较小的“偏差”,这在接近训练结束时是理想的。为了获得最佳性能,我们还发现,根据调度函数 μ()\mu(\cdot)μ\mu 应该随着N的变化而变化。算法3提供了一致性训练的完整算法,附录C给出了我们实验中使用的调度函数。

一致性训练

类似于一致性蒸馏,一致性训练损失 LCTN(θ,theta)\mathcal{L}_{CT}^{N}(\theta,theta^-) 可以扩展到连续时间(即 NN \rightarrow \infty ),若 θ=stopgrad(θ)\theta^- = \mathcal{stopgrad}(\theta) ,如定理6所示。该连续时间损失函数不需要N或 μ\mu 的调度函数,但需要前向模式自动微分进行高效实现。不同于离散时间CT损失,不存在与连续时间目标相关的不良“偏差”,因为我们在定理2中有效地取 t0\triangle t \rightarrow 0 。我们将更多细节放在附录B.2中。

实验

我们采用一致性蒸馏和一致性训练来学习真实图像数据集上的一致性模型,包括CIFAR-10(Krizhevsky等人,2009)、ImageNet 64ˆ64(Deng等人,2009年)、LSUN Bedroom 256ˆ; 256和LSUN Cat 256 \710 256(Yu等人,2015年)。根据Frechet Inception Distance(FID,Heusel等人(2017),越低越好)、Inception Score(is,Salimans等人(2016),越高越好)、Precision(Prec.,Kynka¨anniemi等人(2019),越高越好)和Recall([Rec.,Kynka’anniemi等人(2019)](Improved precision and recall metric for assessing generative models),越高越好)对结果进行比较。附录C中提供了其他实验细节。

训练一致性模型

我们在CIFAR-10上执行一系列实验来理解各种超参数对由一致性蒸馏(CD)和一致性训练(CT)训练的一致性模型的性能的影响。首先关注度量函数 d()d(\cdot) 、ODE求解器,和CD中的离散步N,然后研究调度函数 N()N(\cdot)μ()\mu(\cdot) 在CT中的影响。

为设置CD的实验,考虑:
度量函数:平方 l2\mathcal{l}_2 距离 d(x,y)=xy22d(x,y) = ||x - y||_2^2 , l1\mathcal{l}_1 距离 d(x,y)=xy1d(x,y) = ||x - y||_1 , 和学习感知图像补丁相似性(LPIPS)。
ODE求解器:比较Euler的前向方法和Heun的二阶方法,如EDM所述。
离散步N:对比 N{9,12,18,36,50,60,80,120}N \in \lbrace 9,12,18,36,50,60,80,120\rbrace
实验中,由CD训练的所有一致性模型均由相应的预训练扩散模型初始化,而由CT训练的模型则时随机初始化。

所图3a所示,CD的最优度量指标是LPIPS,在所有训练迭代中,其性能都大大优于 l1\mathcal{l}_1l2\mathcal{l}_2 。这是意料之中的,因为一致性模型的输出是CIFAR-10上的图像,而LPIPS是专门为测量自然图像之间的相似性而设计的。接下来,我们研究哪种ODE求解器和哪种离散化步骤N对CD最有效。如图3b和3c所示,Heun ODE求解器和N=18是最佳选择。尽管我们正在训练一致性模型,而不是扩散模型,但两者都符合EDM的建议。此外,图3b显示,在相同的N下,Heun的二阶求解器的性能均匀优于Euler的一阶求解器。这与定理1相吻合,定理1指出,在相同的N下,由高阶ODE求解器训练的最优一致性模型具有较小的估计误差。图3c的结果还表明,一旦N足够大,CD的性能对N就变得不敏感。鉴于这些见解,除非另有说明,否则我们将在下文中使用LPIPS和Heun ODE求解器求解CD。对于CD中的N,我们遵循EDM中在CIFAR-10和ImageNet 64x64上的建议。我们在其他数据集上分别调整N(详见附录C)

由于CD和CT之间的紧密联系,本文中我们采用LPIPS进行CT实验。与CD不同,在CT中不需要使用Heun的二阶求解器,因为损失函数不依赖于任何特定的数值ODE求解器。如图3d所示,CT的收敛对N高度敏感——较小的N导致较快的收敛,但采样较差,而较大的N收敛较慢,但一旦收敛,样本较好。这与章节5中我们的分析一致,并促使我们实际选择逐步增加CT的N和µ,以平衡收敛速度和样本质量之间的权衡。如图3d所示,N和µ的自适应调度显著提高了CT的收敛速度和样本质量。在我们的实验中,我们针对不同分辨率的图像分别调整了 N()N(\cdot)μ()\mu(\cdot) 的调度,更多细节见附录C。

影响CIFAR-10一致性蒸馏(CD)和一致性训练(CT)的各种因素。CD的最佳配置是LPIPS、Heun ODE求解器和N=18。我们针对N和µ的自适应调度函数使CT的收敛速度明显快于在优化过程中将其固定为常数。

几步图像生成

蒸馏 在目前的文献中,与我们的一致性蒸馏(CD)最直接可比的方法是渐进蒸馏(PD,Salimans&Ho(2022));到目前为止,这两种方法都是唯一在蒸馏前不构建合成数据的蒸馏方法。与之形成鲜明对比的是,其他蒸馏技术,如知识蒸馏DFNO,必须通过昂贵的数值ODE/SDE求解器从扩散模型中生成大量样本来准备大型合成数据集。我们在CIFAR-10、ImageNet 64和LSUN 256上对PD和CD进行了全面比较,所有结果如图4所示。所有方法都是从我们内部预训练的EDM模型中提炼出来的。我们注意到,在所有采样迭代中,与Salimans&Ho(2022)的原始论文中的平方ℓ2距离相比,使用LPIPS度量可以均匀地提高PD。随着我们采取更多的采样步骤,PD和CD都有所改善。我们发现,在所有考虑的数据集、采样步骤和度量函数中,CD的性能均优于PD,但Vedriin 256×256256 \times 256 上的单步生成除外,其中ℓ2的CD性能略低于ℓ2的PD。如表1所示,CD甚至优于需要合成数据集构建的蒸馏方法,如知识蒸馏DFNO

直接生成 在表1和表2中,我们比较了一致性训练(CT)与使用一步和两步生成的其他生成模型的样本质量。我们还包括PD和CD结果以供参考。这两个表都报告了从ℓ2度量函数获得的PD结果,因为这是PD,Salimans&Ho(2022)原始论文中使用的默认设置。为了公平比较,我们确保PD和CD提取相同的EDM模型。在表1和表2中,我们观察到CT在CIFAR-10上的表现明显优于现有的单步、非对抗性生成模型,即VAE和归一化流。此外,CT在不依赖蒸馏的情况下实现了与PD一步样品相当的质量。在图5中,我们提供了EDM样本(顶部)、单步CT样本(中间)和两步CT样本(底部)。在附录E中,我们在图14至21中显示了CD和CT的其他样本。重要的是,从相同的初始噪声向量中获得的所有样本都具有显著的结构相似性,即使CT和EDM模型是彼此独立训练的。这表明CT不太可能像EDM那样遭受模式崩溃。

EDM(顶部)、CT+单步生成(中间)和CT+两步生成(底部)生成的样本。所有相应的图像都是由相同的初始噪声生成的。

零样本图像编辑

与扩散模型类似,一致性模型允许通过修改算法1中的多步采样过程进行零样本图像编辑。我们使用一致性蒸馏在LSUN卧室数据集上训练的一致性模型证明了这一能力。在图6a中,我们展示了这种一致性模型可以在测试时对灰度卧室图像进行着色,即使它从未在着色任务上进行过训练。在图6b中,我们展示了相同的一致性模型可以从低分辨率输入生成高分辨率图像。在图6c中,我们还证明了它可以根据人类创建的笔划输入生成图像,就像扩散模型的SDEdit一样(Meng等人,2021)。同样,这种编辑功能是零样本,因为模型尚未在笔划输入上进行训练。在附录D中,我们还演示了零样本一致性模型在修复(图10)、插值(图11)和去噪(图12)方面的能力,以及更多关于彩色化(图8)、超分辨率(图9)和笔划引导图像生成(图13)的示例。

零样本图像编辑

结论

我们引入了一致性模型,这是一种专门用于支持单步和少步生成的生成模型。我们已经实证证明,我们的一致性蒸馏方法在多个图像基准和小采样迭代上优于现有的扩散模型蒸馏技术。此外,作为一个独立的生成模型,一致性模型比现有的单步生成模型(GAN除外)生成更好的样本。与扩散模型类似,它们还允许零样本图像编辑应用,如修复、着色、超分辨率、去噪、插值和笔划引导图像生成。

此外,一致性模型与其他领域采用的技术有着惊人的相似之处,包括深度Q学习(Mnih等人,2015)和基于动量的对比学习(Grill等人,2020He等人,2020)。这为这些不同领域的思想和方法的异花授粉提供了令人兴奋的前景。