随机微分方程SDE
笔者建议,学完DDPM再来看SDE的作用和推导过程
标准布朗运动
在学习随机微分方程之前,我们先来看一下什么是标准布朗运动
假设有一个一维的直线,有个小人从原点出发,每次随机地选择向左走1格或者向右走1格,且向左走和向右走的两个选项,被选择的概率相等
→
\rightarrow
→用
S
t
S_t
St代表小人离原点的距离,
t
t
t代表代表选择的次数,如果选择的次数越多,那么
S
t
S_t
St将会逐渐服从一个均值为0、方差为
t
t
t的正态分布
布朗运动
W
(
t
)
W(t)
W(t)是期望为0、方差为
t
t
t的正态分布
⇔
\Leftrightarrow
⇔
W
t
∼
N
(
0
,
t
)
W_t\sim \mathcal{N}(0,t)
Wt∼N(0,t)
⇒
\Rightarrow
⇒
W
t
+
Δ
t
−
W
t
∼
N
(
0
,
Δ
t
)
W_{t+\Delta t}-W_t\sim \mathcal{N}(0,\Delta t)
Wt+Δt−Wt∼N(0,Δt),当
Δ
t
→
0
\Delta t\rightarrow 0
Δt→0时,
d
w
=
d
t
ε
dw=\sqrt{dt}\varepsilon
dw=dtε(重参数技巧)
SDE加噪
在DDPM中,扩散过程被划分为固定的T步
⇒
\Rightarrow
⇒DDPM=拆楼+建楼
⇒
\Rightarrow
⇒“拆楼”和“建楼”都被事先划分为了T步,这个划分有着相当大的人为性。事实上,真实的“拆”、“建”过程应该是没有刻意划分的步骤
⇒
\Rightarrow
⇒可以将它们理解为一个在时间上连续的变换过程,可以用随机微分方程(Stochastic Differential Equation,SDE)来描述,即
d
x
=
f
t
(
x
)
d
t
+
g
t
d
w
t
d\boldsymbol{x}=\boldsymbol{f}_t(\boldsymbol{x})dt+g_td\boldsymbol{w_t}
dx=ft(x)dt+gtdwt,其中
f
t
(
x
t
)
f_t(x_t)
ft(xt)是漂移项,描述数据的确定性演化;
g
t
g_t
gt是扩散项,描述的是噪声的扩散程度;
d
w
t
dw_t
dwt是维纳运动(布朗运动)的微小增量,表示随机波动
随机微分方程:
d
x
=
dx=
dx=确定的变化
+
+
+随机的变化,其中随机的变化代表着随机性
随机微分方程描述了系统从
t
t
t时刻到
t
+
Δ
t
t+\Delta t
t+Δt时刻的变化
我们可以将随机微分方程看成是
x
t
+
Δ
t
−
x
t
=
f
t
(
x
t
)
Δ
t
+
g
t
Δ
t
ε
,
ε
∼
N
(
0
,
I
)
\boldsymbol{x}_{t+\Delta t}-\boldsymbol{x}_t=\boldsymbol{f}_t(\boldsymbol{x}_t)\Delta t+g_t\sqrt{\Delta t}\boldsymbol{\varepsilon},\quad\boldsymbol{\varepsilon}\sim\mathcal{N}(\boldsymbol{0},\boldsymbol{I})
xt+Δt−xt=ft(xt)Δt+gtΔtε,ε∼N(0,I)在
Δ
t
→
0
\Delta t\rightarrow 0
Δt→0时的极限
⇒
\Rightarrow
⇒如果建楼要1天,那么拆楼就是
x
x
x从
t
=
0
t=0
t=0到
t
=
1
t=1
t=1时刻的变化
越小的步数
Δ
t
\Delta t
Δt意味着对原始噪声越好的近似,如果
Δ
t
=
0.001
\Delta t=0.001
Δt=0.001,对应着
T
=
1000
T=1000
T=1000;如果
Δ
t
=
0.01
\Delta t=0.01
Δt=0.01,则对应
T
=
100
T=100
T=100(总时间步数
T
T
T是模拟的总时间跨度被步长
Δ
t
\Delta t
Δt划分的次数
T
=
t
max
Δ
t
T=\frac{t_{\max}}{\Delta t}
T=Δttmax)
⇒
\Rightarrow
⇒引入SDE的本质好处是“将理论分析和代码实现分离开来”
DDPM的加噪过程本质上是一个SDE,而SDE本质上描述的是微小时间变化下系统状态的变化
- DDPM的加噪: x t + 1 = 1 − β t x t + β t ϵ x_{t+1}=\sqrt{1-\beta_t}x_t+\sqrt{\beta_t}\epsilon xt+1=1−βtxt+βtϵ
- SDE的加噪: d x = f t ( x ) d t + g t d w t d\boldsymbol{x}=\boldsymbol{f}_t(\boldsymbol{x})dt+g_td\boldsymbol{w_t} dx=ft(x)dt+gtdwt
在这里,笔者介绍一下将DDPM加噪公式映射到SDE加噪公式的推导过程:
- 重写DDPM加噪公式:
x
t
+
1
−
x
t
=
(
1
−
β
t
−
1
)
x
t
+
β
t
ϵ
x_{t+1}-x_t=(\sqrt{1-\beta_t}-1)x_t+\sqrt{\beta_t}\epsilon
xt+1−xt=(1−βt−1)xt+βtϵ
⇒
\Rightarrow
⇒
1
−
β
t
≈
1
−
β
t
2
\sqrt{1-\beta_t}\approx1-\frac{\beta_t}2
1−βt≈1−2βt
⇒
\Rightarrow
⇒将DDPM加噪公式重新表示为一个确定项和随机噪声项的和:
x
t
+
1
−
x
t
≈
−
β
t
2
x
t
+
β
t
ϵ
x_{t+1}-x_{t}\approx-\frac{\beta_{t}}{2}x_{t}+\sqrt{\beta_{t}}\epsilon
xt+1−xt≈−2βtxt+βtϵ
在这里,使用泰勒展开得到 1 − β t ≈ 1 − β t 2 \sqrt{1-\beta_t}\approx1-\frac{\beta_t}2 1−βt≈1−2βt
先来介绍一下泰勒展开:如果 f ( x ) f(x) f(x)在 x = a x=a x=a处是可微的,则它的泰勒展开可以写为 f ( x ) ≈ f ( a ) + f ′ ( a ) ( x − a ) + f ′ ′ ( a ) 2 ! ( x − a ) 2 + … f(x)\approx f(a)+f'(a)(x-a)+\frac{f''(a)}{2!}(x-a)^2+\ldots f(x)≈f(a)+f′(a)(x−a)+2!f′′(a)(x−a)2+…,其中 f ′ ( a ) f'(a) f′(a)和 f ′ ′ ( a ) f''(a) f′′(a)分别是 f ( x ) f(x) f(x)在 a a a处的一阶导数和二阶导数;在泰勒展开中,若函数依赖多个变量,需要对每个变量分别进行展开
f ( β t ) = 1 − β t f(\beta_t)=\sqrt{1-\beta_t} f(βt)=1−βt在 β t = 0 \beta_t=0 βt=0处展开 ⇒ \Rightarrow ⇒零阶项: f ( 0 ) = 1 − 0 = 1 f(0)=\sqrt{1-0}=1 f(0)=1−0=1;一阶导数: f ′ ( β t ) = d d β t 1 − β t = − 1 2 1 − β t f'(\beta_t)=\frac{d}{d\beta_t}\sqrt{1-\beta_t}=\frac{-1}{2\sqrt{1-\beta_t}} f′(βt)=dβtd1−βt=21−βt−1,在 β t = 0 \beta_t=0 βt=0处 f ′ ( 0 ) = − 1 2 1 − 0 = − 1 2 f'(0)=\frac{-1}{2\sqrt{1-0}}=-\frac{1}{2} f′(0)=21−0−1=−21 ⇒ \Rightarrow ⇒ f ( β t ) ≈ f ( 0 ) + f ′ ( 0 ) β t f(\beta_t)\approx f(0)+f'(0)\beta_t f(βt)≈f(0)+f′(0)βt ⇒ \Rightarrow ⇒ 1 − β t ≈ 1 − 1 2 β t \sqrt{1-\beta_t}\approx1-\frac{1}{2}\beta_t 1−βt≈1−21βt - 引入 Δ t \Delta t Δt: Δ t \Delta t Δt只是在数学上引入的时间增量,而 β t \beta_t βt在离散模型中的定义是独立于 Δ t \Delta t Δt的,将DDPM离散的加噪过程转换为连续时间的随机微分方程描述: x t + Δ t − x t ≈ − β t 2 x t Δ t + β t Δ t ϵ x_{t+\Delta t}-x_t\approx-\frac{\beta_t}{2}x_t\Delta t+\sqrt{\beta_t\Delta t}\epsilon xt+Δt−xt≈−2βtxtΔt+βtΔtϵ ⇒ \Rightarrow ⇒ d x = − 1 2 β t x t d t + β t d w dx=-\frac{1}{2}\beta_tx_tdt+\sqrt{\beta_t}dw dx=−21βtxtdt+βtdw
- SDE的形式:漂移项
f
t
(
x
t
)
=
−
β
t
2
x
t
f_t(x_t)=-\frac{\beta_t}2x_t
ft(xt)=−2βtxt,扩散系数
g
t
=
β
t
g_{t}=\sqrt{\beta_{t}}
gt=βt
左侧是数据分布,右侧是正态分布,t是连续时间
SDE去噪
SDE去噪的目标是求
p
(
x
t
∣
x
t
+
Δ
t
)
p(x_t|x_{t+\Delta t})
p(xt∣xt+Δt)
已知:
x
t
+
Δ
t
x_{t+\Delta t}
xt+Δt和前向SDE过程
p
(
x
t
+
Δ
t
∣
x
t
)
p(x_{t+\Delta t}|x_t)
p(xt+Δt∣xt)
⇒
\Rightarrow
⇒贝叶斯公式:
p
(
x
t
∣
x
t
+
Δ
t
)
=
p
(
x
t
+
Δ
t
∣
x
t
)
p
(
x
t
)
p
(
x
t
+
Δ
t
)
p(x_t|x_{t+\Delta t})=\frac{p(x_{t+\Delta t}|x_t)p(x_t)}{p(x_{t+\Delta t})}
p(xt∣xt+Δt)=p(xt+Δt)p(xt+Δt∣xt)p(xt)
为了简化问题,尽可能使
p
(
x
t
∣
x
t
+
Δ
t
)
p(x_t|x_{t+\Delta t})
p(xt∣xt+Δt)的分布满足正态分布
- x t + Δ t = x t + f t ( x t ) Δ t + g t Δ t ϵ x_{t+\Delta t}=x_t+f_t(x_t)\Delta t+g_t\sqrt{\Delta t}\epsilon xt+Δt=xt+ft(xt)Δt+gtΔtϵ ⇒ \Rightarrow ⇒根据重参数可得: x t + Δ t ∼ N ( x t + f t ( x t ) Δ t , g t 2 Δ t ) x_{t+\Delta t}\sim\mathcal{N}(x_t+f_t(x_t)\Delta t,g_t^2\Delta t) xt+Δt∼N(xt+ft(xt)Δt,gt2Δt)
- 正态分布的概率密度函数: f ( x ) = 1 σ 2 π e x p ( − ( x − μ ) 2 2 σ 2 ) f(x)=\frac1{\sigma\sqrt{2\pi}}exp(-\frac{(x-\mu)^2}{2\sigma^2}) f(x)=σ2π1exp(−2σ2(x−μ)2)
- p ( x t ∣ x t + Δ t ) = e x p ( − ( x t + Δ t − x t − f t ( x t ) Δ t ) 2 2 g t 2 Δ t + l o g p ( x t ) − l o g p ( x t + Δ t ) ) ( 1 ) \begin{aligned} p(x_t|x_{t+\Delta t}) &=exp(-\frac{(x_{t+\Delta t}-x_t-f_t(x_t)\Delta t)^2}{2g_t^2\Delta t}+logp(x_t)-logp(x_{t+\Delta t}))&&(1)\\ \end{aligned} p(xt∣xt+Δt)=exp(−2gt2Δt(xt+Δt−xt−ft(xt)Δt)2+logp(xt)−logp(xt+Δt))(1)
- 在
x
t
x_t
xt处泰勒展开
l
o
g
p
(
x
t
+
Δ
t
)
logp(x_{t+\Delta t})
logp(xt+Δt):
log
p
(
x
t
+
Δ
t
,
t
+
Δ
t
)
≈
log
p
(
x
t
,
t
)
+
(
x
t
+
Δ
t
−
x
t
)
∇
x
log
p
(
x
t
,
t
)
⏟
状杰变化的影响
+
Δ
t
⋅
∇
t
log
p
(
x
t
,
t
)
⏟
时间推移的影响
\log p(x_{t+\Delta t},t+\Delta t)\approx\log p(x_t,t)+\underbrace{(x_{t+\Delta t}-x_t)\nabla_x\log p(x_t,t)}_\text{状杰变化的影响}+\underbrace{\Delta t\cdot\nabla_t\log p(x_t,t)}_\text{时间推移的影响}
logp(xt+Δt,t+Δt)≈logp(xt,t)+状杰变化的影响
(xt+Δt−xt)∇xlogp(xt,t)+时间推移的影响
Δt⋅∇tlogp(xt,t),其中状态梯度
∇
x
log
p
(
x
t
)
\nabla_{x}\log p(x_{t})
∇xlogp(xt)描述概率密度在空间中的“漂移”趋势(例如粒子倾向于从高密度区向低密度区移动),时间梯度
∇
t
log
p
(
x
t
)
\nabla_{t}\log p(x_{t})
∇tlogp(xt)描述概率密度随时间的整体衰减或增长
在这里笔者介绍一下为什么会多出一项 ∇ t log p ( x t ) \nabla_{t}\log p(x_{t}) ∇tlogp(xt): p ( x t ) p(x_t) p(xt)实际上是“ t t t时刻随机变量等于 x t x_t xt的概率密度”, p ( x t + Δ t ) p(x_{t+\Delta t}) p(xt+Δt)实际上是“ t + Δ t t+\Delta t t+Δt时刻随机变量等于 x t + Δ t x_{t+\Delta t} xt+Δt的概率密度”,即 p ( x t ) p(x_t) p(xt)实际上同时是时间 t t t和状态变量 x t x_t xt的函数
时间梯度项 ∇ t log p ( x t ) \nabla_{t}\log p(x_{t}) ∇tlogp(xt)的必要性: ∇ t log p ( x t ) \nabla_{t}\log p(x_{t}) ∇tlogp(xt)描述了概率密度的动态演化,即使状态 x t x_t xt不变,概率密度 p ( x t ) p(x_t) p(xt)也会随时间 t t t变化 - 当 Δ t → 0 \Delta t\rightarrow 0 Δt→0时, Δ 2 t = 0 \Delta^2 t=0 Δ2t=0
-
(
1
)
=
p
(
x
t
∣
x
t
+
Δ
t
)
=
e
x
p
(
−
(
x
t
+
Δ
t
−
x
t
−
(
f
t
(
x
t
)
−
g
t
2
∇
x
l
o
g
p
(
x
t
)
)
Δ
t
)
2
2
g
t
2
Δ
t
)
=
e
x
p
(
−
(
x
t
−
(
x
t
+
Δ
t
−
(
f
t
(
x
t
)
−
g
t
2
∇
x
l
o
g
p
(
x
t
)
)
Δ
t
)
)
2
2
g
t
2
Δ
t
)
(
2
)
\begin{aligned} (1)=p(x_t|x_{t+\Delta t}) &=exp(-\frac{(x_{t+\Delta t}-x_t-(f_t(x_t)-g_t^2\nabla xlogp(x_t))\Delta t)^2}{2g_t^2\Delta t})\\ &=exp(-\frac{(x_t-(x_{t+\Delta t}-(f_{t}(x_t)-g_t^2\nabla xlogp(x_t))\Delta t))^2}{2g_t^2\Delta t})&&(2) \end{aligned}
(1)=p(xt∣xt+Δt)=exp(−2gt2Δt(xt+Δt−xt−(ft(xt)−gt2∇xlogp(xt))Δt)2)=exp(−2gt2Δt(xt−(xt+Δt−(ft(xt)−gt2∇xlogp(xt))Δt))2)(2)
将 x t + Δ t − x t x_{t+\Delta t}-x_t xt+Δt−xt和 ( f t ( x t ) − g t 2 ∇ x l o g p ( x t ) ) Δ t (f_t(x_t)-g_t^2\nabla xlogp(x_t))\Delta t (ft(xt)−gt2∇xlogp(xt))Δt分别看作一个整体, ( x t + Δ t − x t − ( f t ( x t ) − g t 2 ∇ x l o g p ( x t ) ) Δ t ) 2 2 g t 2 Δ t \frac{(x_{t+\Delta t}-x_t-(f_t(x_t)-g_t^2\nabla xlogp(x_t))\Delta t)^2}{2g_t^2\Delta t} 2gt2Δt(xt+Δt−xt−(ft(xt)−gt2∇xlogp(xt))Δt)2开方后可以得到和 ( x t + Δ t − x t − f t ( x t ) Δ t ) 2 2 g t 2 Δ t + l o g p ( x t ) − l o g p ( x t + Δ t ) \frac{(x_{t+\Delta t}-x_t-f_t(x_t)\Delta t)^2}{2g_t^2\Delta t}+logp(x_t)-logp(x_{t+\Delta t}) 2gt2Δt(xt+Δt−xt−ft(xt)Δt)2+logp(xt)−logp(xt+Δt)相同的结果 - 当 Δ t → 0 \Delta t\rightarrow 0 Δt→0时, t + Δ t → t t+\Delta t\rightarrow t t+Δt→t
- ( 2 ) = e x p ( − ( x t − ( x t + Δ t − ( f t + Δ t ( x t + Δ t ) − g t + Δ t 2 ∇ x l o g p ( x t + Δ t ) ) Δ t ) ) 2 2 g t + Δ t 2 Δ t ) \begin{aligned} (2)=exp(-\frac{(x_t-(x_{t+\Delta t}-(f_{t+\Delta t}(x_{t+\Delta t})-g_{t+\Delta t}^2\nabla xlogp(x_{t+\Delta t}))\Delta t))^2}{2g_{t+\Delta t}^2\Delta t}) \end{aligned} (2)=exp(−2gt+Δt2Δt(xt−(xt+Δt−(ft+Δt(xt+Δt)−gt+Δt2∇xlogp(xt+Δt))Δt))2)
- 均值 μ = x t + Δ t − ( f t + Δ t ( x t + Δ t ) − g t + Δ t 2 ∇ x l o g p ( x t + Δ t ) ) Δ t \mu=x_{t+\Delta t}-(f_{t+\Delta t}(x_{t+\Delta t})-g_{t+\Delta t}^2\nabla xlogp(x_{t+\Delta t}))\Delta t μ=xt+Δt−(ft+Δt(xt+Δt)−gt+Δt2∇xlogp(xt+Δt))Δt、方差 g t + Δ t 2 Δ t g_{t+\Delta t}^2\Delta t gt+Δt2Δt
- x t = x t + Δ t − ( f t + Δ t ( x t + Δ t ) − g t + Δ t 2 ∇ x l o g p ( x t + Δ t ) ) Δ t + g t + Δ t Δ t ϵ x_t=x_{t+\Delta t}-(f_{t+\Delta t}(x_{t+\Delta t})-g_{t+\Delta t}^2\nabla xlogp(x_{t+\Delta t}))\Delta t+g_{t+\Delta t}\sqrt{\Delta t}\epsilon xt=xt+Δt−(ft+Δt(xt+Δt)−gt+Δt2∇xlogp(xt+Δt))Δt+gt+ΔtΔtϵ
- 当 Δ t → 0 \Delta t\rightarrow 0 Δt→0时, Δ t → d t \Delta t\rightarrow dt Δt→dt
- x t + Δ t − x t = ( f t + Δ t ( x t + Δ t ) − g t + Δ t 2 ∇ x l o g p ( x t + Δ t ) ) Δ t + g t + Δ t Δ t ϵ x_{t+\Delta t}-x_t=(f_{t+\Delta t}(x_{t+\Delta t})-g_{t+\Delta t}^2\nabla xlogp(x_{t+\Delta t}))\Delta t+g_{t+\Delta t}\sqrt{\Delta t}\epsilon xt+Δt−xt=(ft+Δt(xt+Δt)−gt+Δt2∇xlogp(xt+Δt))Δt+gt+ΔtΔtϵ
- d x t = ( f t ( x t ) − g t 2 ∇ x l o g p ( x t ) ) d t + g t d t ϵ = ( f t ( x t ) − g t 2 ∇ x l o g p ( x t ) ) d t + g t d w ˉ \begin{aligned} dx_t=(f_{t}(x_{t})-g_{t}^2\nabla xlogp(x_{t}))dt+g_{t}\sqrt{dt}\epsilon=(f_{t}(x_{t})-g_{t}^2\nabla xlogp(x_{t}))dt+g_{t}d\bar{w} \end{aligned} dxt=(ft(xt)−gt2∇xlogp(xt))dt+gtdtϵ=(ft(xt)−gt2∇xlogp(xt))dt+gtdwˉ
Score Matching
Score Matching 是一种用于估计概率密度函数
p
(
x
)
p(x)
p(x) 的方法,核心方法是直接估计概率密度的梯度(称为score)而非密度本身
通过最小化模型估计的score与真实score之间的差异,可以简介学习数据的分布
我们先来看一下
∇
x
t
log
p
(
x
t
)
\nabla_{x_t}\log p(\boldsymbol{x}_t)
∇xtlogp(xt)的含义:对数概率密度函数
log
p
(
x
t
)
\log p(x_{t})
logp(xt)关于
x
t
x_t
xt的梯度方向指向概率密度增加最快的方向、梯度大小反映密度变化的速率
⇒
\Rightarrow
⇒使用一个
θ
\theta
θ参数化的概率分布
p
θ
p_{\theta}
pθ模拟
p
p
p,通过学习参数
θ
\theta
θ使
p
θ
p_{\theta}
pθ接近
p
p
p
我们可以将
p
θ
p_{\theta}
pθ看成是由两部分组成的,分别是表示密度的函数
p
θ
~
\tilde{p_{\theta}}
pθ~、归一化因子
Z
θ
Z_{\theta}
Zθ
⇒
\Rightarrow
⇒
p
θ
(
x
)
=
p
~
θ
(
x
)
Z
θ
=
p
~
θ
(
x
)
∫
x
∈
X
p
~
θ
(
x
)
d
x
p_\theta(x)=\frac{\tilde{p}_\theta(x)}{Z_\theta}=\frac{\tilde{p}_\theta(x)}{\int_{x\in X}\tilde{p}_\theta(x)dx}
pθ(x)=Zθp~θ(x)=∫x∈Xp~θ(x)dxp~θ(x),其中未归一化的概率密度函数
p
θ
~
\tilde{p_{\theta}}
pθ~给出某个数据点
x
x
x相对于其他数据点的可能性大小,但并不能给出直接用于表示
x
x
x发生的真实概率
目前,使用极大似然估计求解
θ
\theta
θ的问题:不知道归一化因子
Z
θ
Z_{\theta}
Zθ的值
解决方法:
- 引入得分函数(score function):概率密度函数的梯度 ∇ x log p θ ( x ) \nabla_x\log p_\theta(x) ∇xlogpθ(x)
- 将 p θ ( x ) p_{\theta}(x) pθ(x)通过 l o g log log拆分成两项 ∇ x log p ~ θ ( x ) − ∇ x log Z θ \nabla_{x}\log\tilde{p}_{\theta}(x)-\nabla_{x}\log Z_{\theta} ∇xlogp~θ(x)−∇xlogZθ ⇒ \Rightarrow ⇒由于求解的是 x x x的梯度,所以可以直接消掉 ∇ x log Z θ \nabla_{x}\log Z_{\theta} ∇xlogZθ,因为 ∇ x log Z θ \nabla_{x}\log Z_{\theta} ∇xlogZθ与 x x x无关;同时 p θ ~ \tilde{p_{\theta}} pθ~不受“概率分布”的约束,可以使用神经网络作为 p θ ~ \tilde{p_{\theta}} pθ~,因为 p θ ~ \tilde{p_{\theta}} pθ~本身就不是概率密度函数, p θ ~ \tilde{p_{\theta}} pθ~只是密度函数
- 目标:选择一个loss让 ∇ x log p ~ θ ( x ) \nabla_x\log \tilde{p}_\theta(x) ∇xlogp~θ(x)尽可能接近 ∇ x t log p ( x t ) \nabla_{\boldsymbol{x}_t}\log p(\boldsymbol{x}_t) ∇xtlogp(xt)
新的问题:不知道数据分布的 score function
∇
x
t
log
p
(
x
t
)
\nabla_{\boldsymbol{x}_t}\log p(\boldsymbol{x}_t)
∇xtlogp(xt)
为了简化公式,下面公式中的
∇
x
log
p
(
x
)
\nabla_{\boldsymbol{x}}\log p(\boldsymbol{x})
∇xlogp(x)等同于
∇
x
t
log
p
(
x
t
)
\nabla_{\boldsymbol{x}_t}\log p(\boldsymbol{x}_t)
∇xtlogp(xt)
解决方法:Score Matching
- Score Matching:用于估计概率密度函数的梯度(得分函数 score ( x ) = ∇ x t log p ( x t ) \operatorname{score}(x)=\nabla_{\boldsymbol{x}_t}\log p(\boldsymbol{x}_t) score(x)=∇xtlogp(xt)),而无需知道密度函数的归一化常数
- Score Matching的目标:学习一个模型 q ( x ; θ ) q(x;\theta) q(x;θ),使得模型得分函数 ∇ x log q ( x ; θ ) \nabla_x\log q(x;\theta) ∇xlogq(x;θ)与真实分布 p ( x ) p(x) p(x)的得分函数尽可能接近
- Score Matching的损失函数: L ( θ ) = E x ∼ p ( x ) [ 1 2 ∥ ∇ x log q ( x ; θ ) − ∇ x log p ( x ) ∥ 2 ] L(\theta)=\mathbb{E}_{x\sim p(x)}\left[\frac12\left\|\nabla_x\log q(x;\theta)-\nabla_x\log p(x)\right\|^2\right] L(θ)=Ex∼p(x)[21∥∇xlogq(x;θ)−∇xlogp(x)∥2],其中的期望差异可以帮助模型更全面地学习到真实分布的特征
接下来,对Score Matching的损失函数 L ( θ ) = E x ∼ p ( x ) [ 1 2 ∥ ∇ x log q ( x ; θ ) − ∇ x log p ( x ) ∥ 2 ] L(\theta)=\mathbb{E}_{x\sim p(x)}\left[\frac12\left\|\nabla_x\log q(x;\theta)-\nabla_x\log p(x)\right\|^2\right] L(θ)=Ex∼p(x)[21∥∇xlogq(x;θ)−∇xlogp(x)∥2]进行推导:
- 展开欧几里得范数的平方项: ∥ ∇ x log q ( x ; θ ) − ∇ x log p ( x ) ∥ 2 = ∥ ∇ x log q ( x ; θ ) ∥ 2 − 2 ∇ x log q ( x ; θ ) ⋅ ∇ x log p ( x ) + ∥ ∇ x log p ( x ) ∥ 2 \|\nabla_x\log q(x;\theta)-\nabla_x\log p(x)\|^2=\|\nabla_x\log q(x;\theta)\|^2-2\nabla_x\log q(x;\theta)\cdot\nabla_x\log p(x)+\|\nabla_x\log p(x)\|^2 ∥∇xlogq(x;θ)−∇xlogp(x)∥2=∥∇xlogq(x;θ)∥2−2∇xlogq(x;θ)⋅∇xlogp(x)+∥∇xlogp(x)∥2
- 将上式代入原始损失函数中可得 L ( θ ) = E x ∼ p ( x ) [ 1 2 ( ∥ ∇ x log q ( x ; θ ) ∥ 2 − 2 ∇ x log q ( x ; θ ) ⋅ ∇ x log p ( x ) + ∥ ∇ x log p ( x ) ∥ 2 ) ] L(\theta)=\mathbb{E}_{x\sim p(x)}\left[\frac12\left(\|\nabla_x\log q(x;\theta)\|^2-2\nabla_x\log q(x;\theta)\cdot\nabla_x\log p(x)+\|\nabla_x\log p(x)\|^2\right)\right] L(θ)=Ex∼p(x)[21(∥∇xlogq(x;θ)∥2−2∇xlogq(x;θ)⋅∇xlogp(x)+∥∇xlogp(x)∥2)]
- 消除不可计算的项:由于不知道真实分布的 ∇ x log p ( x ) \nabla_{\boldsymbol{x}}\log p(\boldsymbol{x}) ∇xlogp(x),我们无法直接计算 ∥ ∇ x log p ( x ) ∥ 2 \|\nabla_x\log p(x)\|^2 ∥∇xlogp(x)∥2和 ∇ x log q ( x ; θ ) ⋅ ∇ x log p ( x ) \nabla_x\log q(x;\theta)\cdot\nabla_x\log p(x) ∇xlogq(x;θ)⋅∇xlogp(x)
接下来,笔者给出如何消除不可计算项的过程:
- 由于 ∥ ∇ x log p ( x ) ∥ 2 \|\nabla_x\log p(x)\|^2 ∥∇xlogp(x)∥2和 θ \theta θ无关,它仅仅依赖于真实数据分布 p ( x ) p(x) p(x),所以可以直接消掉 ∥ ∇ x log p ( x ) ∥ 2 \|\nabla_x\log p(x)\|^2 ∥∇xlogp(x)∥2
- 对损失函数中的项 ∇ x log q ( x ; θ ) ⋅ ∇ x log p ( x ) \nabla_{x}\log q(x;\theta)\cdot\nabla_{x}\log p(x) ∇xlogq(x;θ)⋅∇xlogp(x)进行分部积分 ⇒ \Rightarrow ⇒ ∫ p ( x ) ∇ x log q ( x ; θ ) ⋅ ∇ x log p ( x ) d x = − ∫ p ( x ) ∇ x 2 log q ( x ; θ ) d x \int p(x)\nabla_{\boldsymbol{x}}\log q(\boldsymbol{x};\theta)\cdot\nabla_{\boldsymbol{x}}\log p(\boldsymbol{x}) d\boldsymbol{x}=-\int p(\boldsymbol{x})\nabla_{x}^{2}\log q(\boldsymbol{x};\theta) d\boldsymbol{x} ∫p(x)∇xlogq(x;θ)⋅∇xlogp(x)dx=−∫p(x)∇x2logq(x;θ)dx
∫ p ( x ) ∇ x log q ( x ; θ ) ⋅ ∇ x log p ( x ) d x = − ∫ p ( x ) ∇ x 2 log q ( x ; θ ) d x \int p(x)\nabla_{\boldsymbol{x}}\log q(\boldsymbol{x};\theta)\cdot\nabla_{\boldsymbol{x}}\log p(\boldsymbol{x}) d\boldsymbol{x}=-\int p(\boldsymbol{x})\nabla_{x}^{2}\log q(\boldsymbol{x};\theta) d\boldsymbol{x} ∫p(x)∇xlogq(x;θ)⋅∇xlogp(x)dx=−∫p(x)∇x2logq(x;θ)dx的推导过程:
- ∫ p ( x ) ∇ x log q ( x ; θ ) ⋅ ∇ x p ( x ) p ( x ) d x = ∫ ∇ x log q ( x ; θ ) ⋅ ∇ x p ( x ) d x \int p(\boldsymbol{x})\nabla_{x}\log q(\boldsymbol{x};\theta)\cdot\frac{\nabla_{\boldsymbol{x}}p(\boldsymbol{x})}{p(\boldsymbol{x})} d\boldsymbol{x}=\int\nabla_{x}\log q(\boldsymbol{x};\theta)\cdot\nabla_{x}p(\boldsymbol{x}) d\boldsymbol{x} ∫p(x)∇xlogq(x;θ)⋅p(x)∇xp(x)dx=∫∇xlogq(x;θ)⋅∇xp(x)dx,其中 ∇ x log p ( x ) = ∇ x p ( x ) p ( x ) \nabla_x\log p(\boldsymbol{x})=\frac{\nabla_xp(\boldsymbol{x})}{p(\boldsymbol{x})} ∇xlogp(x)=p(x)∇xp(x)
- 分部积分: ∫ u d v = u v ∣ a b − ∫ v d u \int udv=uv|_a^b-\int vdu ∫udv=uv∣ab−∫vdu、 d v = ∇ x p ( x ) d x dv=\nabla_xp(x)dx dv=∇xp(x)dx, u = ∇ x log q ( x ; θ ) u=\nabla_{x}\log q(\boldsymbol{x};\theta) u=∇xlogq(x;θ) ⇒ \Rightarrow ⇒ d u = ∇ x 2 l o g q ( x ; θ ) d x 、 du=\nabla^2_xlogq(x;\theta)dx、 du=∇x2logq(x;θ)dx、 v = p ( x ) v=p(x) v=p(x)
-
∫
∇
x
log
q
(
x
;
θ
)
⋅
∇
x
p
(
x
)
d
x
=
∇
x
log
q
(
x
;
θ
)
⋅
p
(
x
)
∣
x
=
−
∞
x
=
∞
−
∫
p
(
x
)
∇
x
2
log
q
(
x
;
θ
)
d
x
\int\nabla_x\log q(\boldsymbol{x};\theta)\cdot\nabla_xp(\boldsymbol{x}) d\boldsymbol{x}=\nabla_{x}\log q(\boldsymbol{x};\theta)·p(x)|_{x=-\infty}^{x=\infty}-\int p(\boldsymbol{x})\nabla^2_x\log q(\boldsymbol{x};\theta) d\boldsymbol{x}
∫∇xlogq(x;θ)⋅∇xp(x)dx=∇xlogq(x;θ)⋅p(x)∣x=−∞x=∞−∫p(x)∇x2logq(x;θ)dx

梯度是一个向量,表示在函数在某一点处变化最快的方向和速率;散度是一个标量,表示向量场在某一点处的“扩散”程度。散度为正,表示向量场从该点扩散;散度为负,表示向量场向该点聚集 - 当x趋于无穷大时,概率密度通常趋于0,,所以 ∫ ∇ x log q ( x ; θ ) ⋅ ∇ x p ( x ) d x = − ∫ p ( x ) ∇ x 2 log q ( x ; θ ) d x \int\nabla_x\log q(\boldsymbol{x};\theta)\cdot\nabla_xp(\boldsymbol{x})d\boldsymbol{x}=-\int p(\boldsymbol{x})\nabla_x^2\log q(\boldsymbol{x};\theta)d\boldsymbol{x} ∫∇xlogq(x;θ)⋅∇xp(x)dx=−∫p(x)∇x2logq(x;θ)dx
- L ( θ ) = E x ∼ p ( x ) [ 1 2 ∥ ∇ x log q ( x ; θ ) − ∇ x log p ( x ) ∥ 2 ] = E x ∼ p ( x ) [ 1 2 ∥ ∇ x log q ( x ; θ ) ∥ 2 + ∇ x 2 log q ( x ; θ ) ] \begin{aligned} L(\theta) &=\mathbb{E}_{x\sim p(x)}\left[\frac12\left\|\nabla_x\log q(x;\theta)-\nabla_x\log p(x)\right\|^2\right]\\ &=\mathbb{E}_{\boldsymbol{x}\sim p(\boldsymbol{x})}\left[\frac{1}{2}\|\nabla_{\boldsymbol{x}}\log q(\boldsymbol{x};\theta)\|^2+\nabla_{\boldsymbol{x}}^2\log q(\boldsymbol{x};\theta)\right] \end{aligned} L(θ)=Ex∼p(x)[21∥∇xlogq(x;θ)−∇xlogp(x)∥2]=Ex∼p(x)[21∥∇xlogq(x;θ)∥2+∇x2logq(x;θ)]
至此,我们可以通过损失函数
L
(
θ
)
L(\theta)
L(θ)使
∇
x
log
q
(
x
;
θ
)
\nabla_x\log q(x;\theta)
∇xlogq(x;θ)接近
∇
x
log
p
(
x
)
\nabla_{\boldsymbol{x}}\log p(\boldsymbol{x})
∇xlogp(x),进而求出SDE去噪过程中的
d
x
dx
dx
笔者也是刚刚接触SDE,如果文中出现错误,请各位读者指正
参考文献
1、生成扩散模型漫谈(五):一般框架之SDE篇
2、SDE公式推导
3、SDE的底层原理
4、AIGC: SGM (Score-based Generative Model) 笔记
5、Score Matching(得分匹配)
6、从零开始的扩散模型 | 基于分数的生成模型解释


2万+

被折叠的 条评论
为什么被折叠?



