筆記來源 [DeepBayes2019]: Day 5, Lecture 3. Langevin dynamics for sampling and global optimization 前半小時. 非常精彩!
粒子 x follow Langevin dynamics 的話: x−x′=−∇U(x′)dt+N(0,σ2dt)
∂∂tpt(x)=∇pt(x)T∇U(x)+pt(x)div∇U(x)+12σ2∇2pt(x)
∂∂tpt(x)=div(pt(x)∇U(x))+12σ2∇2pt(x)
而如果將要採樣的目標分佈 p(x) 設定成這種 stationary 分佈的話.
由於是 stationary 表示就算繼續 follow Langevin dynamics 讓粒子 x 移動 (更新), 更新後的值仍然滿足目標分佈 p(x), 因此達到採樣效果!
而這也是 Denoising Diffusion Probabilistic Models (DDPM) 做採樣時的方法.
接著詳細記錄 Langevin dynamics, Fokker-Planck equation 推導, 以及 stationary 分佈和採樣方法.
如果讀者知道 Continuity equation 的話, 應該會發現與 F-P equation 非常相似. 它們的關聯可以參考 “Flow Matching for Generative Modeling” 論文的 Appendix D.
Langevin Dynamics
從 Langevin dynamics 出發, 考慮如下的 Stochastic Differential Equation (SDE) 其中 X(t) 是 random process:
dX(t)=−∇U(X(t))dt⏟Force+σdBt⏟random fluctuation
對它做離散逼近:
Xt+1−Xt=−dt∇U(Xt)+σ√dtN(0,I)
下圖顯示粒子使用 (1) 的移動軌跡 (來源):
這個 SDE 其實跟 Machine Learning 的 Gradient Descent 很有關聯, 改寫一下:
Wt+1−Wt=−ε∇L(Wt)+σ√εN(0,I)
會發現就是 gradient descent 公式多一個 random 項.
Fokker-Planck Equation
Fokker-Planck equation 描述了如果 partical x 的移動遵從 Langevin dynamics, 則 density 隨著時間的變化, i.e. ∂∂tpt(x), 可以被描述出來
∂∂tpt(x)=∇pt(x)T∇U(x)+pt(x)div∇U(x)+12σ2∇2pt(x)
隨著時間似乎會達到 stationary distribution.
詳細推導請看最後一段 Appendix. 另外補充 F-P equation 與 Continuity equation 的關係可以參考論文的 Appendix D.
Stationary distribution
要怎麼找到這樣的 distribution? 流程就是先假設有 stationary density 且令其為 Gibbs distribution, pG(x), 的形式
pG(x)=1Zexp(−U(x)T),whereZ=∫exp(−U(x)T)dx
我們最終得到 T=σ2/2. 說明了 stationary distribution 的長相為 T=σ2/2 的 Gibbs distribution.
[T=σ2/2 的推導]
我們會利用到 divergence, ∇⋅, 具有 linearity 性質.
把 pG(x) 帶入到 Fokker-Planck equation 其中 normalization term Z 可以忽略 (會被除掉)
0=(∇exp(−U(x)/T))T∇U(x)+exp(−U(x)/T)∇⋅∇U(x)+12σ2∇⋅∇exp(−U(x)/T)=∇⋅[exp(−U(x)/T)∇U(x)]+12σ2∇⋅∇exp(−U(x)/T)=∇⋅[exp(−U(x)/T)∇U(x)−σ22Texp(−U(x)/T)∇U(x)]=∇⋅[(1−σ22T)exp(−U(x)/T)∇U(x)]=0
因此 =0 只有可能中括號內 =0, 所以
(1−σ22T)exp(−U(x)/T)∇U(x)=0⟹1−σ22T=0⟹T=σ22
對目標分布做採樣
因此我們知道, 如果使用 Langevin equation 做擴散的話且希望它最終達到 target distribution p(x), 只要定義 U(x)=−logp(x) and σ=√2, i.e. T=1. 則根據 pG(x) 的定義, pG(x) 會等於我們的 target distribution p(x), 又已知 pG(x) 為 stationary distribution, 所以用 Langevin equation 擴散隨著時間到最後等同於從 pG(x)=p(x) 採樣.
所以可以這麼做 Langevin dynamics sampling (只利用到 score function 做 sampling) (圖片來源) (圖片來源)
Appendix: Derivation of the Fokker-Planck Equation
重複一下 Langevin dynamics
dX(t)=−∇U(X(t))dt+σdBt
x−x′=−∇U(x′)dt+N(0,σ2dt)⟹x∼N(x′−∇U(x′)dt,σ2dt):=q(x|x′)
所以 pt(x) 我們可以這麼寫:
pt(x)=∫pt−dt(x′)q(x|x′)dx′
q(x|x′)=1(2πσ2dt)n/2exp(−(x′−x−∇U(x′)dt)22σ2dt)
y≜x′−x−∇U(x′)dt:=f(x′)
pt(x)=∫pt−dt(x′(y))N(y|0,σ2dt⋅I)|∂x′∂y|dy
先處理比較好做的 |∂x′/∂y| (謝謝某位讀者來信討論, 幫助推導)
根據定義 y=x′−x−∇U(x′)dt 且我們有 (I−A)−1=I+A+A2+A3+... 的公式 [ref], 先計算 ∂y/∂x′ 得到:
∂y∂x′=I−H(x′)dt
Hij=∂2U∂xi∂xj
∂x′∂y=(∂y∂x′)−1=(I−H(x′)dt)−1=I+H(x′)dt+o(dt)
|I+H(x′)dt|=|1+∂2U∂x21dt∂2U∂x1∂x2dt∂2U∂x2∂x1dt1+∂2U∂x22dt|=(1+∂2U∂x21dt)(1+∂2U∂x22dt)+o(dt)=(1+∂2U∂x21dt)+(1+∂2U∂x21dt)∂2U∂x22dt+o(dt)=1+∂2U∂x21dt+∂2U∂x22dt+o(dt)=1+div∇U(x′)dt+o(dt)
|∂x′∂y|=1+div∇U(x′)dt+o(dt)
y≜x′−x−∇U(x′)dt=x′−x−(∇U(x)+(x′−x)∂∇U(x)∂x+o(x′−x))dt
=(I−∂∇U(x)∂xdt)x′−x−∇U(x)dt+x∂∇U(x)∂xdt+o(dt)⟹x′=(I−∂∇U(x)∂xdt)−1(y+x+∇U(x)dt−x∂∇U(x)∂xdt+o(dt))
(I−∂∇U(x)∂xdt)−1=I+∂∇U(x)∂xdt+o(dt)
x′=(I+∂∇U(x)∂xdt+o(dt))(y+x+∇U(x)dt−∂∇U(x)∂xxdt+o(dt))=y+x+∇U(x)dt−∂∇U(x)∂xxdt+∂∇U(x)∂xydt+∂∇U(x)∂xxdt+o(dt)=y+x+∇U(x)dt+∂∇U(x)∂xydt+o(dt)
根據 y 和 x−x′ 的定義 (8) 和 (4):
y=x′−x−∇U(x′)dt=∇U(x′)dt−N(0,σ2dt)−∇U(x′)dt=−N(0,σ2dt)
ydt=−N(0,σ2dt)dt=−dt√dtN(0,σ2)=o(dt)
x′=y+x+∇U(x)dt+∂∇U(x)∂xo(dt)+o(dt)⟹x′(y)=x+y+∇U(x)dt+o(dt)
pt(x)=∫pt−dt(x′(y))N(y|0,σ2dt⋅I)|∂x′∂y|dy=(1+div∇U(x)dt)Ey[pt−dt(x+y+∇U(x)dt)],wherey∼N(0,σ2dt⋅I)
(0th order): pt(x)
(1st order):
∇pt(x)T(y+∇U(x)dt)+∂∂tpt(x)(−dt)
12(y+∇U(x)dt)T∂2pt(x)∂x2(y+∇U(x)dt)
取 Ey:
(0th order): 與 y 無關, 是 constant:
Ey[pt(x)]=pt(x)
(1st order):
Ey[∇pt(x)T(y+∇U(x)dt)+∂∂tpt(x)(−dt)]=∇pt(x)T(Ey[y]+∇U(x)dt)−∂∂tpt(x)dt=∇pt(x)T∇U(x)dt−∂∂tpt(x)dt
12Ey[(y+∇U(x)dt)T∂2pt(x)∂x2(y+∇U(x)dt)]=12Ey[yt∂2pt(x)∂x2y]+2dt∇U(x)T∂2pt(x)∂x2Ey[y]+o(dt)=12Ey[yt∂2pt(x)∂x2y]+o(dt)=12∑i=j(∂2pt(x)∂x2)iiEy[y2i]+12∑i≠j(∂2pt(x)∂x2)ijEy[yiyj]+o(dt)
=12∑i=j(∂2pt(x)∂x2)iiEy[y2i]+o(dt)=12∑i=j(∂2pt(x)∂x2)iiσ2dt+o(dt)=12∇2pt(x)σ2dt+o(dt)
因此代回去 (14):
pt(x)=(1+div∇U(x)dt)Ey[pt−dt(x+y+∇U(x)dt)]≈(1+div∇U(x)dt)⋅(pt(x)+∇pt(x)T∇U(x)dt−∂∂tpt(x)dt+12∇2pt(x)σ2dt+o(dt))
∂∂tpt(x)=∇pt(x)T∇U(x)+pt(x)div∇U(x)+12σ2∇2pt(x)+o(dt)dt⏟=0
∂∂tpt(x)=∇pt(x)T∇U(x)+pt(x)div∇U(x)+12σ2∇2pt(x)
∂∂tpt(x)=div(pt(x)∇U(x))+12σ2∇2pt(x)