连续归一化流 (CNF) 模型通过一个时间依赖的微分同胚映射将简单分布(如标准正态分布)转化为复杂数据分布。要计算 CNF 模型在任意数据点 x1x_1 的概率,我们需要利用连续性方程和流的轨迹方程。

连续性方程和概率密度的变化

连续性方程描述了概率密度的守恒特性:

ddtpt(x)+div(pt(x)vt(x))=0\frac{d}{dt}p_t(x) + div(p_t(x)v_t(x)) = 0

其中:

  • pt(x)p_t(x) 是时间 t 的概率密度函数。
  • vt(x)v_t(x) 是时间 t 的向量场,由 CNF 模型学习得到。
  • divdiv 是散度算子。

这个方程表明,概率密度的变化率必须与其“流动”情况相互抵消,以保证总概率守恒。

计算概率密度的步骤

步骤 1: 瞬时变量的变化

将连续性方程与流的轨迹方程 (ddtϕt(x)=vt(ϕt(x))\frac{d}{dt}\phi_t(x) = v_t(\phi_t(x))) 结合,可以得到以下瞬时变量的变化:

ddtlogpt(ϕt(x))+div(vt(ϕt(x))=0\frac{d}{dt}\log p_t(\phi_t(x)) + div(v_t(\phi_t(x)) = 0

步骤 2:积分计算对数概率

对上述方程在时间区间[0,1]上积分,得到:

logp1(ϕ1(x))logp0(ϕ0(x))=01div(vt(ϕt(x)))dt\log p_1(\phi_1(x)) - \log p_0(\phi_0(x)) = -\int^1_0 div(v_t(\phi_t(x)))dt

步骤 3:求解常微分方程

为了计算 logp1(x1)\log p_1(x_1),我们需要求解以下常微分方程 (ODE):

ddt[ϕt(x)f(t)]=[vt(ϕt(x))div(vt(ϕt(x)))]\frac{d}{dt}\begin{bmatrix}\phi_t(x)\cr f(t)\end{bmatrix} = \begin{bmatrix}v_t(\phi_t(x))\cr -div(v_t(\phi_t(x)))\end{bmatrix}

其中 f(t)f(t) 是一个辅助变量,用于记录 logpt(ϕt(x))\log p_t(\phi_t(x)) 的变化。初始条件为:

[ϕ0(x)f(0)]=[x0c]\begin{bmatrix}\phi_0(x)\cr f(0)\end{bmatrix} = \begin{bmatrix}x_0\cr c\end{bmatrix}

其中 x0x_0 是初始分布 p0(x)p_0(x)中的样本,cc 是一个常数。

步骤 4:计算最终概率

求解上述 ODE 后,我们可以得到 f(1)f(1)。根据步骤 2 中的积分公式,我们可以得到:

f(1)=c+logp1(x1)logp0(x0)f(1) = c + \log p_1(x_1) - \log p_0(x_0)

因此,我们可以计算出 CNF 模型在数据点 x1x_1 的概率密度:

p1(x1)=exp(f(1)c+logp0(x0))p_1(x_1) = exp(f(1) - c + \log p_0(x_0))

总结

通过利用连续性方程和流的轨迹方程,我们可以推导出计算 CNF 模型在任意数据点 x1 概率密度的方法。这个方法需要求解一个常微分方程,并利用初始条件和积分公式来计算最终的概率密度。