之前聽人介紹 wav2vec [3] 或是看其他人的文章大部分都只有介紹作法, 直到有一天自己去看論文才發現看不懂 CPC [2] (wav2vec 使用 CPC 方法). 因此才決定好好讀一下並記錄.
先將這些方法關係梳理一下, NCE –> CPC (infoNCE) –> wav2vec. 此篇筆記主要紀錄 NCE (Noise Contrastive Estimation)
在做 ML 時常常需要估計手上 training data 的 distribution pd(x). 而我們通常會使用參數 θ, 使得參數的模型跟 pd(x) 一樣. 在現在 DNN 統治的年代可能會說, 不然就用一個 NN 來訓練吧, 如下圖:
給 input x, 丟給 NN 希望直接吐出 pθ(x). 上圖的架構是 x 先丟給參數為 θf 的 NN, 該 NN 最後一層的 outputs 再丟給參數為 w 的 linear layer 最後吐出一個 scalar 值, 該值就是我們要的機率.
而訓練的話就使用 MLE (Maximum Likelihood Estimation) 來求參數 θ.
恩, 問題似乎很單純但真正實作起來卻困難重重. 一個問題是 NN outputs 若要保持 p.d.f. 則必須過 softmax, 確保 sum 起來是 1 (也就是要算 Zθ).
pθ(x)=uθ(x)Zθ=eG(x;θ)Zθwhere Zθ=∑xuθ(x)式 (1) 為 energy-based model, 在做 NN classification 時, NN 的 output 就是 G(x;θ), 也就是常看到的 logit, 經過 softmax 就等同於式 (1) 在做的事
而做這件事情在 x 是 discrete space 但數量很多, 例如 NLP 中 LM vocabulary 很大時, 計算資源會消耗過大.
或是 x 是 continuous space 但是算 Zθ 的積分沒有公式解的情形會做不下去. (不然就要用 sampling 方法, 如 MCMC)
NCE 巧妙的將此 MLE 問題轉化成 binary classification 問題, 從而得到我們要的 MLE 解.
不過在此之前, 我們先來看看 MLE 的 gradient 長什麼樣.
MLE 求解
寫出 likelihood:
likilhood=∏x∼pdpθ(x)Loss 就是 negative log-likelihood
−Lmle=Ex∼pdlogpθ(x)=Ex∼pdloguθ(x)Zθ計算其 gradient:
−∇θLmle=Ex∼pd[∇θloguθ(x)−∇θlogZθ]∇θlogZθ=1Zθ∇θZθ=1Zθ∑x∇θeG(x;θ)=1Zθ∑xeG(x;θ)∇θG(x;θ)=∑x[1ZθeG(x;θ)]∇θG(x;θ)=∑xpθ(x)∇θloguθ(x)=Ex∼pθ∇θloguθ(x)∴ −∇θLmle=Ex∼pd[∇θloguθ(x)−Ex∼pθ∇θloguθ(x)]=Ex∼pd∇θloguθ(x)−Ex∼pθ∇θloguθ(x)=∑x[pd(x)−pθ(x)]∇θloguθ(x)從 (11) 式可以看到, 估計的 pdf 與 training data 的 pdf 差越大 gradient 愈大, 當兩者相同時 gradient 為 0 不 update.
Sigmoid or Logistic Function
在說明 NCE 之前先談一下 sigmoid function. 假設現在我們做二分類問題, 兩個類別 C=1 or C=0. 令 p 是某個 input x 屬於 class 1 的機率 (所以 1−p 就是屬於 class 0 的機率)
定義 log-odd 為 (其實也稱為 logit):
我們知道 sigmoid function σ(x)=11+e−x 將實數 input mapping 到 0 ~ 1 區間的函式. 若我們將 log-odd 代入我們很容易得到:
σ(log-odd)=...=p發現 sigmoid 回傳給我們的是 x 屬於 class 1 的機率值, i.e. σ(log-odd)=p(C=1|x). 所以在二分類問題上, 我們就是訓練一個 NN 能 predict logit 值.
NCE 的 Network 架構
首先 NCE 引入了一個 Noise distribution q(x). 論文提到該 q 只要滿足當 pd(x) nonzero 則 q(x) 也必須 nonzero 就可以.
二分類問題為, 假設要取一個正例 (class 1), 就從 training data pdf pd(x) 取得. 而若要取一個反例 (class 0) 則從 noise pdf q(x) 取得.
我們可以取 Np 個正例以及 Nn 個反例, 代表 prior 為:
因此就可以得到一個 batch 共 Np+Nn 個 samples, 丟入下圖的 NN structure 做二分類問題:
Network 前半段還是跟原來的 MLE 架構一樣, 只是我們期望 NNθ 吐出來的是 logit, 由上面一個 section 我們知道經過 sigmoid 得到的會是 x 屬於 class 1 的機率. 因此很容易就用 xent loss 優化.
神奇的來了, NCE 告訴我們, optimize 這個二分類問題得到的 θ 等於 MLE 要找的 θ!
θnce=θmle且 NN 計算的 logit 直接就變成 MLE 要算的 pθ(x).
同時藉由換成二分類問題, 也避開了很難計算的 Zθ 問題.
為了不影響閱讀流暢度, 推導過程請參照 Appendix
所以我們可以透過引入一個 Noise pdf 來達到估計 training data 的 generative model 了. 這也是為什麼叫做 Noise Contrastive Estimation.
Representation
由於透過 NCE 訓練我們可以得到 θ, 此時只需要用 θf 的 NN 來當作 feature extractor 就可以了.
總結
最後流程可以總結成下面這張圖:
最後聊一下 CPC (Contrastive Predictive Coding) [2]. 我覺得跟 NCE 就兩點不同:
- 我們畫的 NCE 圖裡的 w, 改成論文裡的 ct, 所以變成 network 是一個 conditioned 的 network
- 不是一個二分類問題, 改成 N 選 1 的分類問題 (batch size N, 指出哪一個是正例), 因此用 categorical cross-entorpy 當 loss
所以文章稱這樣的 loss 為 infoNCE loss
同時 CPC [2] 論文中很棒的一點是將這樣的訓練方式也跟 Mutual Information (MI) 連接起來.
證明了最小化 infoNCE loss 其實就是在最大化 representation 與正例的 MI (的 lower bound).
這些背後數學撐起了整個利用 CPC 在 SSL (Self-Supervised Learning) 的基礎. 簡單講就是不需要昂貴的 label 全部都 unsupervised 就能學到很好的 representation.
而近期 facebook 更利用 SSL 學到的好 representation 結合 GAN 在 ASR 達到了 19 年的 STOA WER. 論文: Unsupervised Speech Recognition or see [9]
SSL 好東西, 不試試看嗎?
Appendix
Prior pdf:
p(C=1)=NpNp+Nnp(C=0)=1−p(C=1)
Generative pdf:
p(x|C=1)=pθ(x)p(x|C=0)=q(x)
因此 Posterior pdf:
p(C=1|x)=p(C=1)p(x|C=1)p(C=1)p(x|C=1)+p(C=0)p(x|C=0)=pθ(x)pθ(x)+Nrq(x)p(C=0|x)=p(C=0)p(x|C=0)p(C=1)p(x|C=1)+p(C=0)p(x|C=0)=Nrq(x)pθ(x)+Nrq(x)
其中 Nr=NnNp
因此 likelihood 為:
likilhood=Np∏t=1p(Ct=1|xt)⋅Nn∏t=1p(Ct=0|xt)
Loss 為 negative log-likelihood:
−Lnce=Np∑t=1logp(Ct=1|xt)+Nn∑t=1logp(Ct=0|xt)=Np[1NpNp∑t=1logp(Ct=1|xt)]+Nn[1NnNn∑t=0logp(Ct=0|xt)]∝[1NpNp∑t=1logp(Ct=1|xt)]+Nr[1NnNn∑t=0logp(Ct=0|xt)]
當固定 Nr 但是讓 Np→∞ and Nn→∞. 意味著我們固定正負樣本比例, 但取無窮大的 batch. 重寫上式成:
−Lnce=Ex∼pdlogp(C=1|x)+NrEx∼qlogp(C=0|x)∴−∇θLnce=∇θ[Ex∼pdlogpθ(x)pθ(x)+Nrq(x)+NrEx∼qlogNrq(x)pθ(x)+Nrq(x)]=Ex∼pd∇θlogpθ(x)pθ(x)+Nrq(x)+NrEx∼q∇θlogNrq(x)pθ(x)+Nrq(x)
計算橘色和綠色兩項, 之後再代回來:
∇θlogpθ(x)pθ(x)+Nrq(x)=∇θlog11+Nrq(x)pθ(x)=−∇θlog(1+Nrq(x)pθ(x))=−11+Nrq(x)pθ(x)∇θNrq(x)pθ(x)=−Nrq(x)1+Nrq(x)pθ(x)∇θ1pθ(x)=−Nrq(x)1+Nrq(x)pθ(x)−1p2θ(x)∇θpθ(x)=Nrq(x)pθ(x)+Nrq(x)[1pθ(x)∇θpθ(x)]=Nrq(x)pθ(x)+Nrq(x)∇θlogpθ(x)將 (34), (38) 代回去 (29) 得到:
−∇θLnce=Ex∼pdNrq(x)pθ(x)+Nrq(x)∇θlogpθ(x)−NrEx∼qpθ(x)Nrq(x)+pθ(x)∇θlogpθ(x)=∑x[pd(x)Nrq(x)pθ(x)+Nrq(x)∇θlogpθ(x)]−∑x[q(x)Nrpθ(x)Nrq(x)+pθ(x)∇θlogpθ(x)]=∑x(pd(x)−pθ(x))Nrq(x)pθ(x)+Nrq(x)∇θlogpθ(x)=∑x(pd(x)−pθ(x))q(x)pθ(x)Nr+q(x)∇θlogpθ(x)當 Nr→∞ 意味著我們讓負樣本遠多於正樣本, 上式變成:
limNr→∞−∇θLnce=∑x(pd(x)−pθ(x))q(x)0+q(x)∇θlogpθ(x)=∑x(pd(x)−pθ(x))∇θlogpθ(x)=∑x[pd(x)−pθ(x)](∇θloguθ(x)−∇θlogZθ)
此時我們發現這 gradient 也與 Noise pdf q(x) 無關了!
最後我們將 MLE and NCE 的 gradient 拉出來對比一下:
−∇θLmle=∑x[pd(x)−pθ(x)]∇θloguθ(x)−∇θLnce=∑x[pd(x)−pθ(x)](∇θloguθ(x)−∇θlogZθ)
我們發現 MLE and NCE 只差在一個 normalization factor (or partition) Zθ.
最魔術的地方就在於 NCE 論文 [1] 證明最佳解本身的 logit 已經是 probability 型式, 因此也不需要 normalize factor.
所以我們不妨將 Zθ=1, 結果有:
∇θLmle=∇θLnce⇒θmle=θnceReference
- 2010: Noise-contrastive estimation: A new estimation principle for unnormalized statistical models
- 2019 DeepMind infoNCE/CPC: Representation learning with contrastive predictive coding
- 2019 FB: wav2vec: Unsupervised pre-training for speech recognition
- 2020 MIT & Google: Contrastive Representation Distillation
- Noise Contrastive Estimation 前世今生——从 NCE 到 InfoNCE
- “噪声对比估计”杂谈:曲径通幽之妙
- [译] Noise Contrastive Estimation
- The infoNCE loss in self-supervised learning
- High-performance speech recognition with no supervision at all