Noise Contrastive Estimation (NCE) 筆記


之前聽人介紹 wav2vec [3] 或是看其他人的文章大部分都只有介紹作法, 直到有一天自己去看論文才發現看不懂 CPC [2] (wav2vec 使用 CPC 方法). 因此才決定好好讀一下並記錄.

先將這些方法關係梳理一下, NCE –> CPC (infoNCE) –> wav2vec. 此篇筆記主要紀錄 NCE (Noise Contrastive Estimation)

在做 ML 時常常需要估計手上 training data 的 distribution pd(x). 而我們通常會使用參數 θ, 使得參數的模型跟 pd(x) 一樣. 在現在 DNN 統治的年代可能會說, 不然就用一個 NN 來訓練吧, 如下圖:

給 input x, 丟給 NN 希望直接吐出 pθ(x). 上圖的架構是 x 先丟給參數為 θf 的 NN, 該 NN 最後一層的 outputs 再丟給參數為 w 的 linear layer 最後吐出一個 scalar 值, 該值就是我們要的機率.
而訓練的話就使用 MLE (Maximum Likelihood Estimation) 來求參數 θ.

恩, 問題似乎很單純但真正實作起來卻困難重重. 一個問題是 NN outputs 若要保持 p.d.f. 則必須過 softmax, 確保 sum 起來是 1 (也就是要算 Zθ).

pθ(x)=uθ(x)Zθ=eG(x;θ)Zθwhere Zθ=xuθ(x)

式 (1) 為 energy-based model, 在做 NN classification 時, NN 的 output 就是 G(x;θ), 也就是常看到的 logit, 經過 softmax 就等同於式 (1) 在做的事

而做這件事情在 x 是 discrete space 但數量很多, 例如 NLP 中 LM vocabulary 很大時, 計算資源會消耗過大.
或是 x 是 continuous space 但是算 Zθ 的積分沒有公式解的情形會做不下去. (不然就要用 sampling 方法, 如 MCMC)

NCE 巧妙的將此 MLE 問題轉化成 binary classification 問題, 從而得到我們要的 MLE 解.

不過在此之前, 我們先來看看 MLE 的 gradient 長什麼樣.


MLE 求解

寫出 likelihood:

likilhood=xpdpθ(x)

Loss 就是 negative log-likelihood

Lmle=Expdlogpθ(x)=Expdloguθ(x)Zθ

計算其 gradient:

θLmle=Expd[θloguθ(x)θlogZθ]θlogZθ=1ZθθZθ=1ZθxθeG(x;θ)=1ZθxeG(x;θ)θG(x;θ)=x[1ZθeG(x;θ)]θG(x;θ)=xpθ(x)θloguθ(x)=Expθθloguθ(x) θLmle=Expd[θloguθ(x)Expθθloguθ(x)]=Expdθloguθ(x)Expθθloguθ(x)=x[pd(x)pθ(x)]θloguθ(x)

從 (11) 式可以看到, 估計的 pdf 與 training data 的 pdf 差越大 gradient 愈大, 當兩者相同時 gradient 為 0 不 update.


Sigmoid or Logistic Function

在說明 NCE 之前先談一下 sigmoid function. 假設現在我們做二分類問題, 兩個類別 C=1 or C=0. 令 p 是某個 input x 屬於 class 1 的機率 (所以 1p 就是屬於 class 0 的機率)
定義 log-odd 為 (其實也稱為 logit):

log-odd=logp1p

我們知道 sigmoid function σ(x)=11+ex 將實數 input mapping 到 0 ~ 1 區間的函式. 若我們將 log-odd 代入我們很容易得到:

σ(log-odd)=...=p

發現 sigmoid 回傳給我們的是 x 屬於 class 1 的機率值, i.e. σ(log-odd)=p(C=1|x). 所以在二分類問題上, 我們就是訓練一個 NN 能 predict logit 值.


NCE 的 Network 架構

首先 NCE 引入了一個 Noise distribution q(x). 論文提到該 q 只要滿足當 pd(x) nonzero 則 q(x) 也必須 nonzero 就可以.

二分類問題為, 假設要取一個正例 (class 1), 就從 training data pdf pd(x) 取得. 而若要取一個反例 (class 0) 則從 noise pdf q(x) 取得.
我們可以取 Np 個正例以及 Nn 個反例, 代表 prior 為:

p(C=1)=NpNp+Nnp(C=0)=1p(C=1)

因此就可以得到一個 batch 共 Np+Nn 個 samples, 丟入下圖的 NN structure 做二分類問題:

Network 前半段還是跟原來的 MLE 架構一樣, 只是我們期望 NNθ 吐出來的是 logit, 由上面一個 section 我們知道經過 sigmoid 得到的會是 x 屬於 class 1 的機率. 因此很容易就用 xent loss 優化.

神奇的來了, NCE 告訴我們, optimize 這個二分類問題得到的 θ 等於 MLE 要找的 θ!

θnce=θmle

且 NN 計算的 logit 直接就變成 MLE 要算的 pθ(x).

同時藉由換成二分類問題, 也避開了很難計算的 Zθ 問題.
為了不影響閱讀流暢度, 推導過程請參照 Appendix

所以我們可以透過引入一個 Noise pdf 來達到估計 training data 的 generative model 了. 這也是為什麼叫做 Noise Contrastive Estimation.


Representation

由於透過 NCE 訓練我們可以得到 θ, 此時只需要用 θf 的 NN 來當作 feature extractor 就可以了.


總結

最後流程可以總結成下面這張圖:

最後聊一下 CPC (Contrastive Predictive Coding) [2]. 我覺得跟 NCE 就兩點不同:

  1. 我們畫的 NCE 圖裡的 w, 改成論文裡的 ct, 所以變成 network 是一個 conditioned 的 network
  2. 不是一個二分類問題, 改成 N 選 1 的分類問題 (batch size N, 指出哪一個是正例), 因此用 categorical cross-entorpy 當 loss

所以文章稱這樣的 loss 為 infoNCE loss

同時 CPC [2] 論文中很棒的一點是將這樣的訓練方式也跟 Mutual Information (MI) 連接起來.
證明了最小化 infoNCE loss 其實就是在最大化 representation 與正例的 MI (的 lower bound).

這些背後數學撐起了整個利用 CPC 在 SSL (Self-Supervised Learning) 的基礎. 簡單講就是不需要昂貴的 label 全部都 unsupervised 就能學到很好的 representation.
而近期 facebook 更利用 SSL 學到的好 representation 結合 GAN 在 ASR 達到了 19 年的 STOA WER. 論文: Unsupervised Speech Recognition or see [9]

SSL 好東西, 不試試看嗎?


Appendix

Prior pdf:
p(C=1)=NpNp+Nnp(C=0)=1p(C=1)

Generative pdf:
p(x|C=1)=pθ(x)p(x|C=0)=q(x)

因此 Posterior pdf:
p(C=1|x)=p(C=1)p(x|C=1)p(C=1)p(x|C=1)+p(C=0)p(x|C=0)=pθ(x)pθ(x)+Nrq(x)p(C=0|x)=p(C=0)p(x|C=0)p(C=1)p(x|C=1)+p(C=0)p(x|C=0)=Nrq(x)pθ(x)+Nrq(x)


其中 Nr=NnNp

因此 likelihood 為:
likilhood=Npt=1p(Ct=1|xt)Nnt=1p(Ct=0|xt)

Loss 為 negative log-likelihood:
Lnce=Npt=1logp(Ct=1|xt)+Nnt=1logp(Ct=0|xt)=Np[1NpNpt=1logp(Ct=1|xt)]+Nn[1NnNnt=0logp(Ct=0|xt)][1NpNpt=1logp(Ct=1|xt)]+Nr[1NnNnt=0logp(Ct=0|xt)]

當固定 Nr 但是讓 Np and Nn. 意味著我們固定正負樣本比例, 但取無窮大的 batch. 重寫上式成:
Lnce=Expdlogp(C=1|x)+NrExqlogp(C=0|x)θLnce=θ[Expdlogpθ(x)pθ(x)+Nrq(x)+NrExqlogNrq(x)pθ(x)+Nrq(x)]=Expdθlogpθ(x)pθ(x)+Nrq(x)+NrExqθlogNrq(x)pθ(x)+Nrq(x)

計算橘色和綠色兩項, 之後再代回來:

θlogpθ(x)pθ(x)+Nrq(x)=θlog11+Nrq(x)pθ(x)=θlog(1+Nrq(x)pθ(x))=11+Nrq(x)pθ(x)θNrq(x)pθ(x)=Nrq(x)1+Nrq(x)pθ(x)θ1pθ(x)=Nrq(x)1+Nrq(x)pθ(x)1p2θ(x)θpθ(x)=Nrq(x)pθ(x)+Nrq(x)[1pθ(x)θpθ(x)]=Nrq(x)pθ(x)+Nrq(x)θlogpθ(x)
θlogNrq(x)pθ(x)+Nrq(x)=θlog(1+pθ(x)Nrq(x))=11+pθ(x)Nrq(x)θpθ(x)Nrq(x)=1Nrq(x)+pθ(x)θpθ(x)=pθ(x)Nrq(x)+pθ(x)[1pθ(x)θpθ(x)]=pθ(x)Nrq(x)+pθ(x)θlogpθ(x)

將 (34), (38) 代回去 (29) 得到:

θLnce=ExpdNrq(x)pθ(x)+Nrq(x)θlogpθ(x)NrExqpθ(x)Nrq(x)+pθ(x)θlogpθ(x)=x[pd(x)Nrq(x)pθ(x)+Nrq(x)θlogpθ(x)]x[q(x)Nrpθ(x)Nrq(x)+pθ(x)θlogpθ(x)]=x(pd(x)pθ(x))Nrq(x)pθ(x)+Nrq(x)θlogpθ(x)=x(pd(x)pθ(x))q(x)pθ(x)Nr+q(x)θlogpθ(x)

Nr 意味著我們讓負樣本遠多於正樣本, 上式變成:
limNrθLnce=x(pd(x)pθ(x))q(x)0+q(x)θlogpθ(x)=x(pd(x)pθ(x))θlogpθ(x)=x[pd(x)pθ(x)](θloguθ(x)θlogZθ)

此時我們發現這 gradient 也與 Noise pdf q(x) 無關了!

最後我們將 MLE and NCE 的 gradient 拉出來對比一下:
θLmle=x[pd(x)pθ(x)]θloguθ(x)θLnce=x[pd(x)pθ(x)](θloguθ(x)θlogZθ)

我們發現 MLE and NCE 只差在一個 normalization factor (or partition) Zθ.
最魔術的地方就在於 NCE 論文 [1] 證明最佳解本身的 logit 已經是 probability 型式, 因此也不需要 normalize factor.

論文裡說礙於篇幅沒給出證明, 主要是來自 Theorem 1 的結果:

所以我們不妨將 Zθ=1, 結果有:

θLmle=θLnceθmle=θnce

Reference

  1. 2010: Noise-contrastive estimation: A new estimation principle for unnormalized statistical models
  2. 2019 DeepMind infoNCE/CPC: Representation learning with contrastive predictive coding
  3. 2019 FB: wav2vec: Unsupervised pre-training for speech recognition
  4. 2020 MIT & Google: Contrastive Representation Distillation
  5. Noise Contrastive Estimation 前世今生——从 NCE 到 InfoNCE
  6. “噪声对比估计”杂谈:曲径通幽之妙
  7. [译] Noise Contrastive Estimation
  8. The infoNCE loss in self-supervised learning
  9. High-performance speech recognition with no supervision at all