筆記 Langevin Dynamics 和 Fokker-Planck Equation 推導


筆記來源 [DeepBayes2019]: Day 5, Lecture 3. Langevin dynamics for sampling and global optimization 前半小時. 非常精彩!

粒子 x follow Langevin dynamics 的話: xx=U(x)dt+N(0,σ2dt)

x 隨時間的機率分布 pt(x) 會滿足 Fokker-Planck equation 這種 Stochastic Differential Equation (SDE) 的形式:
tpt(x)=pt(x)TU(x)+pt(x)divU(x)+12σ22pt(x)
或這麼寫也可以 (用 div(pu)=pTu+pdiv(u) 公式, 更多 divergence/curl 的微分[參考這, or YouTube])
tpt(x)=div(pt(x)U(x))+12σ22pt(x)
而從 F-P equation 我們可以發現最後 t 時某種設定下會有 stationary 分佈.
而如果將要採樣的目標分佈 p(x) 設定成這種 stationary 分佈的話.
由於是 stationary 表示就算繼續 follow Langevin dynamics 讓粒子 x 移動 (更新), 更新後的值仍然滿足目標分佈 p(x), 因此達到採樣效果!
而這也是 Denoising Diffusion Probabilistic Models (DDPM) 做採樣時的方法.

接著詳細記錄 Langevin dynamics, Fokker-Planck equation 推導, 以及 stationary 分佈和採樣方法.

如果讀者知道 Continuity equation 的話, 應該會發現與 F-P equation 非常相似. 它們的關聯可以參考 “Flow Matching for Generative Modeling” 論文的 Appendix D.


Langevin Dynamics

從 Langevin dynamics 出發, 考慮如下的 Stochastic Differential Equation (SDE) 其中 X(t)random process:
dX(t)=U(X(t))dtForce+σdBtrandom fluctuation

其中 BtBrownian motion, (或稱 Wiener process)
對它做離散逼近:
Xt+1Xt=dtU(Xt)+σdtN(0,I)
注意到 Bt+dtBtN(0,dt)=dtN(0,I).
下圖顯示粒子使用 (1) 的移動軌跡 (來源):

這個 SDE 其實跟 Machine Learning 的 Gradient Descent 很有關聯, 改寫一下:
Wt+1Wt=εL(Wt)+σεN(0,I)
其中 Wt 表示第 t 次 iteration 時的 parameter, L 表示 loss function.
會發現就是 gradient descent 公式多一個 random 項.


Fokker-Planck Equation

Fokker-Planck equation 描述了如果 partical x 的移動遵從 Langevin dynamics, 則 density 隨著時間的變化, i.e. tpt(x), 可以被描述出來
tpt(x)=pt(x)TU(x)+pt(x)divU(x)+12σ22pt(x)

其中 div,2divergence 和 Laplace operators. 圖片來源

隨著時間似乎會達到 stationary distribution.
詳細推導請看最後一段 Appendix. 另外補充 F-P equation 與 Continuity equation 的關係可以參考論文的 Appendix D.

Stationary distribution

要怎麼找到這樣的 distribution? 流程就是先假設有 stationary density 且令其為 Gibbs distribution, pG(x), 的形式
pG(x)=1Zexp(U(x)T),whereZ=exp(U(x)T)dx

然後帶入 Fokker-Planck equation 觀察什麼樣的情況會滿足.
我們最終得到 T=σ2/2. 說明了 stationary distribution 的長相為 T=σ2/2 的 Gibbs distribution.

[T=σ2/2 的推導]
我們會利用到 divergence, , 具有 linearity 性質.
pG(x) 帶入到 Fokker-Planck equation 其中 normalization term Z 可以忽略 (會被除掉)
0=(exp(U(x)/T))TU(x)+exp(U(x)/T)U(x)+12σ2exp(U(x)/T)=[exp(U(x)/T)U(x)]+12σ2exp(U(x)/T)=[exp(U(x)/T)U(x)σ22Texp(U(x)/T)U(x)]=[(1σ22T)exp(U(x)/T)U(x)]=0

注意到最後一行由於我們沒有對 potential field U(x) 有所限制
因此 =0 只有可能中括號內 =0, 所以
(1σ22T)exp(U(x)/T)U(x)=01σ22T=0T=σ22

對目標分布做採樣

因此我們知道, 如果使用 Langevin equation 做擴散的話且希望它最終達到 target distribution p(x), 只要定義 U(x)=logp(x) and σ=2, i.e. T=1. 則根據 pG(x) 的定義, pG(x) 會等於我們的 target distribution p(x), 又已知 pG(x) 為 stationary distribution, 所以用 Langevin equation 擴散隨著時間到最後等同於從 pG(x)=p(x) 採樣.
所以可以這麼做 Langevin dynamics sampling (只利用到 score function 做 sampling) (圖片來源) (圖片來源)


Appendix: Derivation of the Fokker-Planck Equation

重複一下 Langevin dynamics
dX(t)=U(X(t))dt+σdBt

離散化:
xx=U(x)dt+N(0,σ2dt)xN(xU(x)dt,σ2dt):=q(x|x)
(圖片來源)

所以 pt(x) 我們可以這麼寫:
pt(x)=ptdt(x)q(x|x)dx
注意到 ptdt(x) 我們也不知道. 但我們知道 q(x|x) 定義在式 (5), 展開來:
q(x|x)=1(2πσ2dt)n/2exp((xxU(x)dt)22σ2dt)
定義:
yxxU(x)dt:=f(x)
則根據 change of variables 我們知道式 (6) 變成
pt(x)=ptdt(x(y))N(y|0,σ2dtI)|xy|dy
就算簡化了 q(x|x) 成上式, 藍色的部分 x(y), |x/y| 我們仍不知道, 必須想辦法從 y (8) 的定義反寫.
先處理比較好做的 |x/y| (謝謝某位讀者來信討論, 幫助推導)
根據定義 y=xxU(x)dt 且我們有 (IA)1=I+A+A2+A3+... 的公式 [ref], 先計算 y/x 得到:
yx=IH(x)dt
其中 H(x)U 的 Hessian matrix, 其第 (i,j) element 為:
Hij=2Uxixj
再計算 inverse:
xy=(yx)1=(IH(x)dt)1=I+H(x)dt+o(dt)
觀察 I+H(x)dt 的 determinant, 簡單用 2×2 矩陣來看 (很容易拓展到 n×n):
|I+H(x)dt|=|1+2Ux21dt2Ux1x2dt2Ux2x1dt1+2Ux22dt|=(1+2Ux21dt)(1+2Ux22dt)+o(dt)=(1+2Ux21dt)+(1+2Ux21dt)2Ux22dt+o(dt)=1+2Ux21dt+2Ux22dt+o(dt)=1+divU(x)dt+o(dt)
其中 div 表示 divergence, 所以得到:
|xy|=1+divU(x)dt+o(dt)
再想辦法把 x(y) 寫出來, 比較複雜, 把 U(x)x 這點做 Taylor expansion:
yxxU(x)dt=xx(U(x)+(xx)U(x)x+o(xx))dt
x 合併展開並整理, 並注意到因為根據 (4), o(xx)dt 這項為 o(dt), 繼續推導:
=(IU(x)xdt)xxU(x)dt+xU(x)xdt+o(dt)x=(IU(x)xdt)1(y+x+U(x)dtxU(x)xdt+o(dt))
利用 (IA)1=I+A+A2+A3+... 的公式 [YouTube with time] [ref]:
(IU(x)xdt)1=I+U(x)xdt+o(dt)
代回去得到
x=(I+U(x)xdt+o(dt))(y+x+U(x)dtU(x)xxdt+o(dt))=y+x+U(x)dtU(x)xxdt+U(x)xydt+U(x)xxdt+o(dt)=y+x+U(x)dt+U(x)xydt+o(dt)
先對 ydt 分析一下到時候代回去
根據 yxx 的定義 (8) 和 (4):
y=xxU(x)dt=U(x)dtN(0,σ2dt)U(x)dt=N(0,σ2dt)
所以
ydt=N(0,σ2dt)dt=dtdtN(0,σ2)=o(dt)
代回去得到
x=y+x+U(x)dt+U(x)xo(dt)+o(dt)x(y)=x+y+U(x)dt+o(dt)
因此 (10), (12) 代回去 (9)
pt(x)=ptdt(x(y))N(y|0,σ2dtI)|xy|dy=(1+divU(x)dt)Ey[ptdt(x+y+U(x)dt)],whereyN(0,σ2dtI)
紅色部分做 Taylor expansion 對 pt(x) 展開:
(0th order): pt(x)
(1st order):
pt(x)T(y+U(x)dt)+tpt(x)(dt)
(2nd order):
12(y+U(x)dt)T2pt(x)x2(y+U(x)dt)
2nd order 還有對 t 的二次微分項, (dt)2(2pt(x)/t2) 由於是 o(dt) 所以可以省略不寫

Ey:
(0th order): 與 y 無關, 是 constant:
Ey[pt(x)]=pt(x)


(1st order):
Ey[pt(x)T(y+U(x)dt)+tpt(x)(dt)]=pt(x)T(Ey[y]+U(x)dt)tpt(x)dt=pt(x)TU(x)dttpt(x)dt
(2nd order):
12Ey[(y+U(x)dt)T2pt(x)x2(y+U(x)dt)]=12Ey[yt2pt(x)x2y]+2dtU(x)T2pt(x)x2Ey[y]+o(dt)=12Ey[yt2pt(x)x2y]+o(dt)=12i=j(2pt(x)x2)iiEy[y2i]+12ij(2pt(x)x2)ijEy[yiyj]+o(dt)
因為 yN(0,σ2dtI), see (7) and (8), 所以第二項為零
=12i=j(2pt(x)x2)iiEy[y2i]+o(dt)=12i=j(2pt(x)x2)iiσ2dt+o(dt)=122pt(x)σ2dt+o(dt)
其中 2Laplace operator.
因此代回去 (14):
pt(x)=(1+divU(x)dt)Ey[ptdt(x+y+U(x)dt)](1+divU(x)dt)(pt(x)+pt(x)TU(x)dttpt(x)dt+122pt(x)σ2dt+o(dt))
展開整理得
tpt(x)=pt(x)TU(x)+pt(x)divU(x)+12σ22pt(x)+o(dt)dt=0
重複一次, 這就是最後的 Fokker-Planck equation:
tpt(x)=pt(x)TU(x)+pt(x)divU(x)+12σ22pt(x)
或這麼寫也可以 (用 div(pu)=pTu+pdiv(u) 公式, 更多 divergence/curl 的微分[參考這, or YouTube])
tpt(x)=div(pt(x)U(x))+12σ22pt(x)
Q.E.D.