『扩散模型』一篇文章入门随机微分方程SDE

随机微分方程SDE

笔者建议,学完DDPM再来看SDE的作用和推导过程

标准布朗运动

在学习随机微分方程之前,我们先来看一下什么是标准布朗运动
假设有一个一维的直线,有个小人从原点出发,每次随机地选择向左走1格或者向右走1格,且向左走和向右走的两个选项,被选择的概率相等 → \rightarrow S t S_t St代表小人离原点的距离, t t t代表代表选择的次数,如果选择的次数越多,那么 S t S_t St将会逐渐服从一个均值为0、方差为 t t t的正态分布
布朗运动 W ( t ) W(t) W(t)是期望为0、方差为 t t t的正态分布 ⇔ \Leftrightarrow W t ∼ N ( 0 , t ) W_t\sim \mathcal{N}(0,t) WtN(0,t) ⇒ \Rightarrow W t + Δ t − W t ∼ N ( 0 , Δ t ) W_{t+\Delta t}-W_t\sim \mathcal{N}(0,\Delta t) Wt+ΔtWtN(0,Δt),当 Δ t → 0 \Delta t\rightarrow 0 Δt0时, d w = d t ε dw=\sqrt{dt}\varepsilon dw=dt ε(重参数技巧)

SDE加噪

在DDPM中,扩散过程被划分为固定的T步 ⇒ \Rightarrow DDPM=拆楼+建楼 ⇒ \Rightarrow “拆楼”和“建楼”都被事先划分为了T步,这个划分有着相当大的人为性。事实上,真实的“拆”、“建”过程应该是没有刻意划分的步骤 ⇒ \Rightarrow 可以将它们理解为一个在时间上连续的变换过程,可以用随机微分方程(Stochastic Differential Equation,SDE)来描述,即 d x = f t ( x ) d t + g t d w t d\boldsymbol{x}=\boldsymbol{f}_t(\boldsymbol{x})dt+g_td\boldsymbol{w_t} dx=ft(x)dt+gtdwt,其中 f t ( x t ) f_t(x_t) ft(xt)是漂移项,描述数据的确定性演化 g t g_t gt是扩散项,描述的是噪声的扩散程度 d w t dw_t dwt是维纳运动(布朗运动)的微小增量,表示随机波动
随机微分方程: d x = dx= dx=确定的变化 + + +随机的变化,其中随机的变化代表着随机性
随机微分方程描述了系统从 t t t时刻到 t + Δ t t+\Delta t t+Δt时刻的变化
我们可以将随机微分方程看成是 x t + Δ t − x t = f t ( x t ) Δ t + g t Δ t ε , ε ∼ N ( 0 , I ) \boldsymbol{x}_{t+\Delta t}-\boldsymbol{x}_t=\boldsymbol{f}_t(\boldsymbol{x}_t)\Delta t+g_t\sqrt{\Delta t}\boldsymbol{\varepsilon},\quad\boldsymbol{\varepsilon}\sim\mathcal{N}(\boldsymbol{0},\boldsymbol{I}) xt+Δtxt=ft(xt)Δt+gtΔt ε,εN(0,I) Δ t → 0 \Delta t\rightarrow 0 Δt0时的极限 ⇒ \Rightarrow 如果建楼要1天,那么拆楼就是 x x x t = 0 t=0 t=0 t = 1 t=1 t=1时刻的变化
越小的步数 Δ t \Delta t Δt意味着对原始噪声越好的近似,如果 Δ t = 0.001 \Delta t=0.001 Δt=0.001,对应着 T = 1000 T=1000 T=1000;如果 Δ t = 0.01 \Delta t=0.01 Δt=0.01,则对应 T = 100 T=100 T=100(总时间步数 T T T是模拟的总时间跨度被步长 Δ t \Delta t Δt划分的次数 T = t max ⁡ Δ t T=\frac{t_{\max}}{\Delta t} T=Δttmax ⇒ \Rightarrow 引入SDE的本质好处是将理论分析和代码实现分离开来
DDPM的加噪过程本质上是一个SDE,而SDE本质上描述的是微小时间变化下系统状态的变化

  • DDPM的加噪: x t + 1 = 1 − β t x t + β t ϵ x_{t+1}=\sqrt{1-\beta_t}x_t+\sqrt{\beta_t}\epsilon xt+1=1βt xt+βt ϵ
  • SDE的加噪: d x = f t ( x ) d t + g t d w t d\boldsymbol{x}=\boldsymbol{f}_t(\boldsymbol{x})dt+g_td\boldsymbol{w_t} dx=ft(x)dt+gtdwt

在这里,笔者介绍一下将DDPM加噪公式映射到SDE加噪公式的推导过程:

  1. 重写DDPM加噪公式: x t + 1 − x t = ( 1 − β t − 1 ) x t + β t ϵ x_{t+1}-x_t=(\sqrt{1-\beta_t}-1)x_t+\sqrt{\beta_t}\epsilon xt+1xt=(1βt 1)xt+βt ϵ ⇒ \Rightarrow 1 − β t ≈ 1 − β t 2 \sqrt{1-\beta_t}\approx1-\frac{\beta_t}2 1βt 12βt ⇒ \Rightarrow 将DDPM加噪公式重新表示为一个确定项和随机噪声项的和: x t + 1 − x t ≈ − β t 2 x t + β t ϵ x_{t+1}-x_{t}\approx-\frac{\beta_{t}}{2}x_{t}+\sqrt{\beta_{t}}\epsilon xt+1xt2βtxt+βt ϵ
    在这里,使用泰勒展开得到 1 − β t ≈ 1 − β t 2 \sqrt{1-\beta_t}\approx1-\frac{\beta_t}2 1βt 12βt
    先来介绍一下泰勒展开:如果 f ( x ) f(x) f(x) x = a x=a x=a处是可微的,则它的泰勒展开可以写为 f ( x ) ≈ f ( a ) + f ′ ( a ) ( x − a ) + f ′ ′ ( a ) 2 ! ( x − a ) 2 + … f(x)\approx f(a)+f'(a)(x-a)+\frac{f''(a)}{2!}(x-a)^2+\ldots f(x)f(a)+f(a)(xa)+2!f′′(a)(xa)2+,其中 f ′ ( a ) f'(a) f(a) f ′ ′ ( a ) f''(a) f′′(a)分别是 f ( x ) f(x) f(x) a a a处的一阶导数和二阶导数;在泰勒展开中,若函数依赖多个变量,需要对每个变量分别进行展开
    f ( β t ) = 1 − β t f(\beta_t)=\sqrt{1-\beta_t} f(βt)=1βt β t = 0 \beta_t=0 βt=0处展开 ⇒ \Rightarrow 零阶项: f ( 0 ) = 1 − 0 = 1 f(0)=\sqrt{1-0}=1 f(0)=10 =1;一阶导数: f ′ ( β t ) = d d β t 1 − β t = − 1 2 1 − β t f'(\beta_t)=\frac{d}{d\beta_t}\sqrt{1-\beta_t}=\frac{-1}{2\sqrt{1-\beta_t}} f(βt)=dβtd1βt =21βt 1,在 β t = 0 \beta_t=0 βt=0 f ′ ( 0 ) = − 1 2 1 − 0 = − 1 2 f'(0)=\frac{-1}{2\sqrt{1-0}}=-\frac{1}{2} f(0)=210 1=21 ⇒ \Rightarrow f ( β t ) ≈ f ( 0 ) + f ′ ( 0 ) β t f(\beta_t)\approx f(0)+f'(0)\beta_t f(βt)f(0)+f(0)βt ⇒ \Rightarrow 1 − β t ≈ 1 − 1 2 β t \sqrt{1-\beta_t}\approx1-\frac{1}{2}\beta_t 1βt 121βt
  2. 引入 Δ t \Delta t Δt Δ t \Delta t Δt只是在数学上引入的时间增量,而 β t \beta_t βt在离散模型中的定义是独立于 Δ t \Delta t Δt的,将DDPM离散的加噪过程转换为连续时间的随机微分方程描述: x t + Δ t − x t ≈ − β t 2 x t Δ t + β t Δ t ϵ x_{t+\Delta t}-x_t\approx-\frac{\beta_t}{2}x_t\Delta t+\sqrt{\beta_t\Delta t}\epsilon xt+Δtxt2βtxtΔt+βtΔt ϵ ⇒ \Rightarrow d x = − 1 2 β t x t d t + β t d w dx=-\frac{1}{2}\beta_tx_tdt+\sqrt{\beta_t}dw dx=21βtxtdt+βt dw
  3. SDE的形式:漂移项 f t ( x t ) = − β t 2 x t f_t(x_t)=-\frac{\beta_t}2x_t ft(xt)=2βtxt,扩散系数 g t = β t g_{t}=\sqrt{\beta_{t}} gt=βt
    在这里插入图片描述
    左侧是数据分布,右侧是正态分布,t是连续时间

SDE去噪

SDE去噪的目标是求 p ( x t ∣ x t + Δ t ) p(x_t|x_{t+\Delta t}) p(xtxt+Δt)
已知: x t + Δ t x_{t+\Delta t} xt+Δt和前向SDE过程 p ( x t + Δ t ∣ x t ) p(x_{t+\Delta t}|x_t) p(xt+Δtxt) ⇒ \Rightarrow 贝叶斯公式: p ( x t ∣ x t + Δ t ) = p ( x t + Δ t ∣ x t ) p ( x t ) p ( x t + Δ t ) p(x_t|x_{t+\Delta t})=\frac{p(x_{t+\Delta t}|x_t)p(x_t)}{p(x_{t+\Delta t})} p(xtxt+Δt)=p(xt+Δt)p(xt+Δtxt)p(xt)
为了简化问题,尽可能使 p ( x t ∣ x t + Δ t ) p(x_t|x_{t+\Delta t}) p(xtxt+Δt)的分布满足正态分布

  1. x t + Δ t = x t + f t ( x t ) Δ t + g t Δ t ϵ x_{t+\Delta t}=x_t+f_t(x_t)\Delta t+g_t\sqrt{\Delta t}\epsilon xt+Δt=xt+ft(xt)Δt+gtΔt ϵ ⇒ \Rightarrow 根据重参数可得: x t + Δ t ∼ N ( x t + f t ( x t ) Δ t , g t 2 Δ t ) x_{t+\Delta t}\sim\mathcal{N}(x_t+f_t(x_t)\Delta t,g_t^2\Delta t) xt+ΔtN(xt+ft(xt)Δt,gt2Δt)
  2. 正态分布的概率密度函数: f ( x ) = 1 σ 2 π e x p ( − ( x − μ ) 2 2 σ 2 ) f(x)=\frac1{\sigma\sqrt{2\pi}}exp(-\frac{(x-\mu)^2}{2\sigma^2}) f(x)=σ2π 1exp(2σ2(xμ)2)
  3. p ( x t ∣ x t + Δ t ) = e x p ( − ( x t + Δ t − x t − f t ( x t ) Δ t ) 2 2 g t 2 Δ t + l o g p ( x t ) − l o g p ( x t + Δ t ) ) ( 1 ) \begin{aligned} p(x_t|x_{t+\Delta t}) &=exp(-\frac{(x_{t+\Delta t}-x_t-f_t(x_t)\Delta t)^2}{2g_t^2\Delta t}+logp(x_t)-logp(x_{t+\Delta t}))&&(1)\\ \end{aligned} p(xtxt+Δt)=exp(2gt2Δt(xt+Δtxtft(xt)Δt)2+logp(xt)logp(xt+Δt))(1)
  4. x t x_t xt处泰勒展开 l o g p ( x t + Δ t ) logp(x_{t+\Delta t}) logp(xt+Δt) log ⁡ p ( x t + Δ t , t + Δ t ) ≈ log ⁡ p ( x t , t ) + ( x t + Δ t − x t ) ∇ x log ⁡ p ( x t , t ) ⏟ 状杰变化的影响 + Δ t ⋅ ∇ t log ⁡ p ( x t , t ) ⏟ 时间推移的影响 \log p(x_{t+\Delta t},t+\Delta t)\approx\log p(x_t,t)+\underbrace{(x_{t+\Delta t}-x_t)\nabla_x\log p(x_t,t)}_\text{状杰变化的影响}+\underbrace{\Delta t\cdot\nabla_t\log p(x_t,t)}_\text{时间推移的影响} logp(xt+Δt,t+Δt)logp(xt,t)+状杰变化的影响 (xt+Δtxt)xlogp(xt,t)+时间推移的影响 Δttlogp(xt,t),其中状态梯度 ∇ x log ⁡ p ( x t ) \nabla_{x}\log p(x_{t}) xlogp(xt)描述概率密度在空间中的“漂移”趋势(例如粒子倾向于从高密度区向低密度区移动),时间梯度 ∇ t log ⁡ p ( x t ) \nabla_{t}\log p(x_{t}) tlogp(xt)描述概率密度随时间的整体衰减或增长
    在这里笔者介绍一下为什么会多出一项 ∇ t log ⁡ p ( x t ) \nabla_{t}\log p(x_{t}) tlogp(xt) p ( x t ) p(x_t) p(xt)实际上是“ t t t时刻随机变量等于 x t x_t xt的概率密度”, p ( x t + Δ t ) p(x_{t+\Delta t}) p(xt+Δt)实际上是“ t + Δ t t+\Delta t t+Δt时刻随机变量等于 x t + Δ t x_{t+\Delta t} xt+Δt的概率密度”,即 p ( x t ) p(x_t) p(xt)实际上同时是时间 t t t和状态变量 x t x_t xt的函数
    时间梯度项 ∇ t log ⁡ p ( x t ) \nabla_{t}\log p(x_{t}) tlogp(xt)的必要性: ∇ t log ⁡ p ( x t ) \nabla_{t}\log p(x_{t}) tlogp(xt)描述了概率密度的动态演化,即使状态 x t x_t xt不变,概率密度 p ( x t ) p(x_t) p(xt)也会随时间 t t t变化
  5. Δ t → 0 \Delta t\rightarrow 0 Δt0时, Δ 2 t = 0 \Delta^2 t=0 Δ2t=0
  6. ( 1 ) = p ( x t ∣ x t + Δ t ) = e x p ( − ( x t + Δ t − x t − ( f t ( x t ) − g t 2 ∇ x l o g p ( x t ) ) Δ t ) 2 2 g t 2 Δ t ) = e x p ( − ( x t − ( x t + Δ t − ( f t ( x t ) − g t 2 ∇ x l o g p ( x t ) ) Δ t ) ) 2 2 g t 2 Δ t ) ( 2 ) \begin{aligned} (1)=p(x_t|x_{t+\Delta t}) &=exp(-\frac{(x_{t+\Delta t}-x_t-(f_t(x_t)-g_t^2\nabla xlogp(x_t))\Delta t)^2}{2g_t^2\Delta t})\\ &=exp(-\frac{(x_t-(x_{t+\Delta t}-(f_{t}(x_t)-g_t^2\nabla xlogp(x_t))\Delta t))^2}{2g_t^2\Delta t})&&(2) \end{aligned} (1)=p(xtxt+Δt)=exp(2gt2Δt(xt+Δtxt(ft(xt)gt2xlogp(xt))Δt)2)=exp(2gt2Δt(xt(xt+Δt(ft(xt)gt2xlogp(xt))Δt))2)(2)
    x t + Δ t − x t x_{t+\Delta t}-x_t xt+Δtxt ( f t ( x t ) − g t 2 ∇ x l o g p ( x t ) ) Δ t (f_t(x_t)-g_t^2\nabla xlogp(x_t))\Delta t (ft(xt)gt2xlogp(xt))Δt分别看作一个整体, ( x t + Δ t − x t − ( f t ( x t ) − g t 2 ∇ x l o g p ( x t ) ) Δ t ) 2 2 g t 2 Δ t \frac{(x_{t+\Delta t}-x_t-(f_t(x_t)-g_t^2\nabla xlogp(x_t))\Delta t)^2}{2g_t^2\Delta t} 2gt2Δt(xt+Δtxt(ft(xt)gt2xlogp(xt))Δt)2开方后可以得到和 ( x t + Δ t − x t − f t ( x t ) Δ t ) 2 2 g t 2 Δ t + l o g p ( x t ) − l o g p ( x t + Δ t ) \frac{(x_{t+\Delta t}-x_t-f_t(x_t)\Delta t)^2}{2g_t^2\Delta t}+logp(x_t)-logp(x_{t+\Delta t}) 2gt2Δt(xt+Δtxtft(xt)Δt)2+logp(xt)logp(xt+Δt)相同的结果
  7. Δ t → 0 \Delta t\rightarrow 0 Δt0时, t + Δ t → t t+\Delta t\rightarrow t t+Δtt
  8. ( 2 ) = e x p ( − ( x t − ( x t + Δ t − ( f t + Δ t ( x t + Δ t ) − g t + Δ t 2 ∇ x l o g p ( x t + Δ t ) ) Δ t ) ) 2 2 g t + Δ t 2 Δ t ) \begin{aligned} (2)=exp(-\frac{(x_t-(x_{t+\Delta t}-(f_{t+\Delta t}(x_{t+\Delta t})-g_{t+\Delta t}^2\nabla xlogp(x_{t+\Delta t}))\Delta t))^2}{2g_{t+\Delta t}^2\Delta t}) \end{aligned} (2)=exp(2gt+Δt2Δt(xt(xt+Δt(ft+Δt(xt+Δt)gt+Δt2xlogp(xt+Δt))Δt))2)
  9. 均值 μ = x t + Δ t − ( f t + Δ t ( x t + Δ t ) − g t + Δ t 2 ∇ x l o g p ( x t + Δ t ) ) Δ t \mu=x_{t+\Delta t}-(f_{t+\Delta t}(x_{t+\Delta t})-g_{t+\Delta t}^2\nabla xlogp(x_{t+\Delta t}))\Delta t μ=xt+Δt(ft+Δt(xt+Δt)gt+Δt2xlogp(xt+Δt))Δt、方差 g t + Δ t 2 Δ t g_{t+\Delta t}^2\Delta t gt+Δt2Δt
  10. x t = x t + Δ t − ( f t + Δ t ( x t + Δ t ) − g t + Δ t 2 ∇ x l o g p ( x t + Δ t ) ) Δ t + g t + Δ t Δ t ϵ x_t=x_{t+\Delta t}-(f_{t+\Delta t}(x_{t+\Delta t})-g_{t+\Delta t}^2\nabla xlogp(x_{t+\Delta t}))\Delta t+g_{t+\Delta t}\sqrt{\Delta t}\epsilon xt=xt+Δt(ft+Δt(xt+Δt)gt+Δt2xlogp(xt+Δt))Δt+gt+ΔtΔt ϵ
  11. Δ t → 0 \Delta t\rightarrow 0 Δt0时, Δ t → d t \Delta t\rightarrow dt Δtdt
  12. x t + Δ t − x t = ( f t + Δ t ( x t + Δ t ) − g t + Δ t 2 ∇ x l o g p ( x t + Δ t ) ) Δ t + g t + Δ t Δ t ϵ x_{t+\Delta t}-x_t=(f_{t+\Delta t}(x_{t+\Delta t})-g_{t+\Delta t}^2\nabla xlogp(x_{t+\Delta t}))\Delta t+g_{t+\Delta t}\sqrt{\Delta t}\epsilon xt+Δtxt=(ft+Δt(xt+Δt)gt+Δt2xlogp(xt+Δt))Δt+gt+ΔtΔt ϵ
  13. d x t = ( f t ( x t ) − g t 2 ∇ x l o g p ( x t ) ) d t + g t d t ϵ = ( f t ( x t ) − g t 2 ∇ x l o g p ( x t ) ) d t + g t d w ˉ \begin{aligned} dx_t=(f_{t}(x_{t})-g_{t}^2\nabla xlogp(x_{t}))dt+g_{t}\sqrt{dt}\epsilon=(f_{t}(x_{t})-g_{t}^2\nabla xlogp(x_{t}))dt+g_{t}d\bar{w} \end{aligned} dxt=(ft(xt)gt2xlogp(xt))dt+gtdt ϵ=(ft(xt)gt2xlogp(xt))dt+gtdwˉ

Score Matching

Score Matching 是一种用于估计概率密度函数 p ( x ) p(x) p(x) 的方法,核心方法是直接估计概率密度的梯度(称为score)而非密度本身
通过最小化模型估计的score与真实score之间的差异,可以简介学习数据的分布
我们先来看一下 ∇ x t log ⁡ p ( x t ) \nabla_{x_t}\log p(\boldsymbol{x}_t) xtlogp(xt)的含义:对数概率密度函数 log ⁡ p ( x t ) \log p(x_{t}) logp(xt)关于 x t x_t xt的梯度方向指向概率密度增加最快的方向、梯度大小反映密度变化的速率 ⇒ \Rightarrow 使用一个 θ \theta θ参数化的概率分布 p θ p_{\theta} pθ模拟 p p p,通过学习参数 θ \theta θ使 p θ p_{\theta} pθ接近 p p p
我们可以将 p θ p_{\theta} pθ看成是由两部分组成的,分别是表示密度的函数 p θ ~ \tilde{p_{\theta}} pθ~、归一化因子 Z θ Z_{\theta} Zθ ⇒ \Rightarrow p θ ( x ) = p ~ θ ( x ) Z θ = p ~ θ ( x ) ∫ x ∈ X p ~ θ ( x ) d x p_\theta(x)=\frac{\tilde{p}_\theta(x)}{Z_\theta}=\frac{\tilde{p}_\theta(x)}{\int_{x\in X}\tilde{p}_\theta(x)dx} pθ(x)=Zθp~θ(x)=xXp~θ(x)dxp~θ(x),其中未归一化的概率密度函数 p θ ~ \tilde{p_{\theta}} pθ~给出某个数据点 x x x相对于其他数据点的可能性大小,但并不能给出直接用于表示 x x x发生的真实概率
目前,使用极大似然估计求解 θ \theta θ的问题:不知道归一化因子 Z θ Z_{\theta} Zθ的值
解决方法:

  1. 引入得分函数(score function):概率密度函数的梯度 ∇ x log ⁡ p θ ( x ) \nabla_x\log p_\theta(x) xlogpθ(x)
  2. p θ ( x ) p_{\theta}(x) pθ(x)通过 l o g log log拆分成两项 ∇ x log ⁡ p ~ θ ( x ) − ∇ x log ⁡ Z θ \nabla_{x}\log\tilde{p}_{\theta}(x)-\nabla_{x}\log Z_{\theta} xlogp~θ(x)xlogZθ ⇒ \Rightarrow 由于求解的是 x x x的梯度,所以可以直接消掉 ∇ x log ⁡ Z θ \nabla_{x}\log Z_{\theta} xlogZθ,因为 ∇ x log ⁡ Z θ \nabla_{x}\log Z_{\theta} xlogZθ x x x无关;同时 p θ ~ \tilde{p_{\theta}} pθ~不受“概率分布”的约束,可以使用神经网络作为 p θ ~ \tilde{p_{\theta}} pθ~,因为 p θ ~ \tilde{p_{\theta}} pθ~本身就不是概率密度函数, p θ ~ \tilde{p_{\theta}} pθ~只是密度函数
  3. 目标:选择一个loss让 ∇ x log ⁡ p ~ θ ( x ) \nabla_x\log \tilde{p}_\theta(x) xlogp~θ(x)尽可能接近 ∇ x t log ⁡ p ( x t ) \nabla_{\boldsymbol{x}_t}\log p(\boldsymbol{x}_t) xtlogp(xt)

新的问题:不知道数据分布的 score function ∇ x t log ⁡ p ( x t ) \nabla_{\boldsymbol{x}_t}\log p(\boldsymbol{x}_t) xtlogp(xt)
为了简化公式,下面公式中的 ∇ x log ⁡ p ( x ) \nabla_{\boldsymbol{x}}\log p(\boldsymbol{x}) xlogp(x)等同于 ∇ x t log ⁡ p ( x t ) \nabla_{\boldsymbol{x}_t}\log p(\boldsymbol{x}_t) xtlogp(xt)
解决方法:Score Matching

  1. Score Matching:用于估计概率密度函数的梯度(得分函数 score ⁡ ( x ) = ∇ x t log ⁡ p ( x t ) \operatorname{score}(x)=\nabla_{\boldsymbol{x}_t}\log p(\boldsymbol{x}_t) score(x)=xtlogp(xt)),而无需知道密度函数的归一化常数
  2. Score Matching的目标:学习一个模型 q ( x ; θ ) q(x;\theta) q(x;θ),使得模型得分函数 ∇ x log ⁡ q ( x ; θ ) \nabla_x\log q(x;\theta) xlogq(x;θ)与真实分布 p ( x ) p(x) p(x)的得分函数尽可能接近
  3. Score Matching的损失函数: L ( θ ) = E x ∼ p ( x ) [ 1 2 ∥ ∇ x log ⁡ q ( x ; θ ) − ∇ x log ⁡ p ( x ) ∥ 2 ] L(\theta)=\mathbb{E}_{x\sim p(x)}\left[\frac12\left\|\nabla_x\log q(x;\theta)-\nabla_x\log p(x)\right\|^2\right] L(θ)=Exp(x)[21xlogq(x;θ)xlogp(x)2],其中的期望差异可以帮助模型更全面地学习到真实分布的特征

接下来,对Score Matching的损失函数 L ( θ ) = E x ∼ p ( x ) [ 1 2 ∥ ∇ x log ⁡ q ( x ; θ ) − ∇ x log ⁡ p ( x ) ∥ 2 ] L(\theta)=\mathbb{E}_{x\sim p(x)}\left[\frac12\left\|\nabla_x\log q(x;\theta)-\nabla_x\log p(x)\right\|^2\right] L(θ)=Exp(x)[21xlogq(x;θ)xlogp(x)2]进行推导:

  1. 展开欧几里得范数的平方项: ∥ ∇ x log ⁡ q ( x ; θ ) − ∇ x log ⁡ p ( x ) ∥ 2 = ∥ ∇ x log ⁡ q ( x ; θ ) ∥ 2 − 2 ∇ x log ⁡ q ( x ; θ ) ⋅ ∇ x log ⁡ p ( x ) + ∥ ∇ x log ⁡ p ( x ) ∥ 2 \|\nabla_x\log q(x;\theta)-\nabla_x\log p(x)\|^2=\|\nabla_x\log q(x;\theta)\|^2-2\nabla_x\log q(x;\theta)\cdot\nabla_x\log p(x)+\|\nabla_x\log p(x)\|^2 xlogq(x;θ)xlogp(x)2=xlogq(x;θ)22xlogq(x;θ)xlogp(x)+xlogp(x)2
  2. 将上式代入原始损失函数中可得 L ( θ ) = E x ∼ p ( x ) [ 1 2 ( ∥ ∇ x log ⁡ q ( x ; θ ) ∥ 2 − 2 ∇ x log ⁡ q ( x ; θ ) ⋅ ∇ x log ⁡ p ( x ) + ∥ ∇ x log ⁡ p ( x ) ∥ 2 ) ] L(\theta)=\mathbb{E}_{x\sim p(x)}\left[\frac12\left(\|\nabla_x\log q(x;\theta)\|^2-2\nabla_x\log q(x;\theta)\cdot\nabla_x\log p(x)+\|\nabla_x\log p(x)\|^2\right)\right] L(θ)=Exp(x)[21(xlogq(x;θ)22xlogq(x;θ)xlogp(x)+xlogp(x)2)]
  3. 消除不可计算的项:由于不知道真实分布的 ∇ x log ⁡ p ( x ) \nabla_{\boldsymbol{x}}\log p(\boldsymbol{x}) xlogp(x),我们无法直接计算 ∥ ∇ x log ⁡ p ( x ) ∥ 2 \|\nabla_x\log p(x)\|^2 xlogp(x)2 ∇ x log ⁡ q ( x ; θ ) ⋅ ∇ x log ⁡ p ( x ) \nabla_x\log q(x;\theta)\cdot\nabla_x\log p(x) xlogq(x;θ)xlogp(x)

接下来,笔者给出如何消除不可计算项的过程:

  1. 由于 ∥ ∇ x log ⁡ p ( x ) ∥ 2 \|\nabla_x\log p(x)\|^2 xlogp(x)2 θ \theta θ无关,它仅仅依赖于真实数据分布 p ( x ) p(x) p(x),所以可以直接消掉 ∥ ∇ x log ⁡ p ( x ) ∥ 2 \|\nabla_x\log p(x)\|^2 xlogp(x)2
  2. 对损失函数中的项 ∇ x log ⁡ q ( x ; θ ) ⋅ ∇ x log ⁡ p ( x ) \nabla_{x}\log q(x;\theta)\cdot\nabla_{x}\log p(x) xlogq(x;θ)xlogp(x)进行分部积分 ⇒ \Rightarrow ∫ p ( x ) ∇ x log ⁡ q ( x ; θ ) ⋅ ∇ x log ⁡ p ( x ) d x = − ∫ p ( x ) ∇ x 2 log ⁡ q ( x ; θ ) d x \int p(x)\nabla_{\boldsymbol{x}}\log q(\boldsymbol{x};\theta)\cdot\nabla_{\boldsymbol{x}}\log p(\boldsymbol{x}) d\boldsymbol{x}=-\int p(\boldsymbol{x})\nabla_{x}^{2}\log q(\boldsymbol{x};\theta) d\boldsymbol{x} p(x)xlogq(x;θ)xlogp(x)dx=p(x)x2logq(x;θ)dx

∫ p ( x ) ∇ x log ⁡ q ( x ; θ ) ⋅ ∇ x log ⁡ p ( x ) d x = − ∫ p ( x ) ∇ x 2 log ⁡ q ( x ; θ ) d x \int p(x)\nabla_{\boldsymbol{x}}\log q(\boldsymbol{x};\theta)\cdot\nabla_{\boldsymbol{x}}\log p(\boldsymbol{x}) d\boldsymbol{x}=-\int p(\boldsymbol{x})\nabla_{x}^{2}\log q(\boldsymbol{x};\theta) d\boldsymbol{x} p(x)xlogq(x;θ)xlogp(x)dx=p(x)x2logq(x;θ)dx的推导过程:

  1. ∫ p ( x ) ∇ x log ⁡ q ( x ; θ ) ⋅ ∇ x p ( x ) p ( x ) d x = ∫ ∇ x log ⁡ q ( x ; θ ) ⋅ ∇ x p ( x ) d x \int p(\boldsymbol{x})\nabla_{x}\log q(\boldsymbol{x};\theta)\cdot\frac{\nabla_{\boldsymbol{x}}p(\boldsymbol{x})}{p(\boldsymbol{x})} d\boldsymbol{x}=\int\nabla_{x}\log q(\boldsymbol{x};\theta)\cdot\nabla_{x}p(\boldsymbol{x}) d\boldsymbol{x} p(x)xlogq(x;θ)p(x)xp(x)dx=xlogq(x;θ)xp(x)dx,其中 ∇ x log ⁡ p ( x ) = ∇ x p ( x ) p ( x ) \nabla_x\log p(\boldsymbol{x})=\frac{\nabla_xp(\boldsymbol{x})}{p(\boldsymbol{x})} xlogp(x)=p(x)xp(x)
  2. 分部积分: ∫ u d v = u v ∣ a b − ∫ v d u \int udv=uv|_a^b-\int vdu udv=uvabvdu d v = ∇ x p ( x ) d x dv=\nabla_xp(x)dx dv=xp(x)dx u = ∇ x log ⁡ q ( x ; θ ) u=\nabla_{x}\log q(\boldsymbol{x};\theta) u=xlogq(x;θ) ⇒ \Rightarrow d u = ∇ x 2 l o g q ( x ; θ ) d x 、 du=\nabla^2_xlogq(x;\theta)dx、 du=x2logq(x;θ)dx v = p ( x ) v=p(x) v=p(x)
  3. ∫ ∇ x log ⁡ q ( x ; θ ) ⋅ ∇ x p ( x ) d x = ∇ x log ⁡ q ( x ; θ ) ⋅ p ( x ) ∣ x = − ∞ x = ∞ − ∫ p ( x ) ∇ x 2 log ⁡ q ( x ; θ ) d x \int\nabla_x\log q(\boldsymbol{x};\theta)\cdot\nabla_xp(\boldsymbol{x}) d\boldsymbol{x}=\nabla_{x}\log q(\boldsymbol{x};\theta)·p(x)|_{x=-\infty}^{x=\infty}-\int p(\boldsymbol{x})\nabla^2_x\log q(\boldsymbol{x};\theta) d\boldsymbol{x} xlogq(x;θ)xp(x)dx=xlogq(x;θ)p(x)x=x=p(x)x2logq(x;θ)dx
    在这里插入图片描述
    梯度是一个向量,表示在函数在某一点处变化最快的方向和速率;散度是一个标量,表示向量场在某一点处的“扩散”程度。散度为正,表示向量场从该点扩散;散度为负,表示向量场向该点聚集
  4. 当x趋于无穷大时,概率密度通常趋于0,,所以 ∫ ∇ x log ⁡ q ( x ; θ ) ⋅ ∇ x p ( x ) d x = − ∫ p ( x ) ∇ x 2 log ⁡ q ( x ; θ ) d x \int\nabla_x\log q(\boldsymbol{x};\theta)\cdot\nabla_xp(\boldsymbol{x})d\boldsymbol{x}=-\int p(\boldsymbol{x})\nabla_x^2\log q(\boldsymbol{x};\theta)d\boldsymbol{x} xlogq(x;θ)xp(x)dx=p(x)x2logq(x;θ)dx
  5. L ( θ ) = E x ∼ p ( x ) [ 1 2 ∥ ∇ x log ⁡ q ( x ; θ ) − ∇ x log ⁡ p ( x ) ∥ 2 ] = E x ∼ p ( x ) [ 1 2 ∥ ∇ x log ⁡ q ( x ; θ ) ∥ 2 + ∇ x 2 log ⁡ q ( x ; θ ) ] \begin{aligned} L(\theta) &=\mathbb{E}_{x\sim p(x)}\left[\frac12\left\|\nabla_x\log q(x;\theta)-\nabla_x\log p(x)\right\|^2\right]\\ &=\mathbb{E}_{\boldsymbol{x}\sim p(\boldsymbol{x})}\left[\frac{1}{2}\|\nabla_{\boldsymbol{x}}\log q(\boldsymbol{x};\theta)\|^2+\nabla_{\boldsymbol{x}}^2\log q(\boldsymbol{x};\theta)\right] \end{aligned} L(θ)=Exp(x)[21xlogq(x;θ)xlogp(x)2]=Exp(x)[21xlogq(x;θ)2+x2logq(x;θ)]

至此,我们可以通过损失函数 L ( θ ) L(\theta) L(θ)使 ∇ x log ⁡ q ( x ; θ ) \nabla_x\log q(x;\theta) xlogq(x;θ)接近 ∇ x log ⁡ p ( x ) \nabla_{\boldsymbol{x}}\log p(\boldsymbol{x}) xlogp(x),进而求出SDE去噪过程中的 d x dx dx
笔者也是刚刚接触SDE,如果文中出现错误,请各位读者指正

参考文献

1、生成扩散模型漫谈(五):一般框架之SDE篇
2、SDE公式推导
3、SDE的底层原理
4、AIGC: SGM (Score-based Generative Model) 笔记
5、Score Matching(得分匹配)
6、从零开始的扩散模型 | 基于分数的生成模型解释

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值