连续归一化流 (CNF) 模型通过一个时间依赖的微分同胚映射将简单分布(如标准正态分布)转化为复杂数据分布。要计算 CNF 模型在任意数据点 x1 的概率,我们需要利用连续性方程和流的轨迹方程。
连续性方程和概率密度的变化
连续性方程描述了概率密度的守恒特性:
dtdpt(x)+div(pt(x)vt(x))=0
其中:
- pt(x) 是时间 t 的概率密度函数。
- vt(x) 是时间 t 的向量场,由 CNF 模型学习得到。
- div 是散度算子。
这个方程表明,概率密度的变化率必须与其“流动”情况相互抵消,以保证总概率守恒。
计算概率密度的步骤
步骤 1: 瞬时变量的变化
将连续性方程与流的轨迹方程 (dtdϕt(x)=vt(ϕt(x))) 结合,可以得到以下瞬时变量的变化:
dtdlogpt(ϕt(x))+div(vt(ϕt(x))=0
步骤 2:积分计算对数概率
对上述方程在时间区间[0,1]上积分,得到:
logp1(ϕ1(x))−logp0(ϕ0(x))=−∫01div(vt(ϕt(x)))dt
步骤 3:求解常微分方程
为了计算 logp1(x1),我们需要求解以下常微分方程 (ODE):
dtd[ϕt(x)f(t)]=[vt(ϕt(x))−div(vt(ϕt(x)))]
其中 f(t) 是一个辅助变量,用于记录 logpt(ϕt(x)) 的变化。初始条件为:
[ϕ0(x)f(0)]=[x0c]
其中 x0 是初始分布 p0(x)中的样本,c 是一个常数。
步骤 4:计算最终概率
求解上述 ODE 后,我们可以得到 f(1)。根据步骤 2 中的积分公式,我们可以得到:
f(1)=c+logp1(x1)−logp0(x0)
因此,我们可以计算出 CNF 模型在数据点 x1 的概率密度:
p1(x1)=exp(f(1)−c+logp0(x0))
总结
通过利用连续性方程和流的轨迹方程,我们可以推导出计算 CNF 模型在任意数据点 x1 概率密度的方法。这个方法需要求解一个常微分方程,并利用初始条件和积分公式来计算最终的概率密度。