NN 做分類最後一層通常使用 softmax loss, 但如果類別數量很大會導致計算 softmax 的 cost 太高, 這樣會讓訓練變得很慢. 假如總共的 class 數量是 10000 個, candidate sampling 的想法就是對於一個 input x 採樣出一個 subset (當然需要包含正確的 label), 譬如只用 50 個 classes, 扣掉正確的那個 class, 剩下的 49 個 classes 從 9999 個採樣出來. 然後計算 softmax 只在那 50 個計算. 那麼問題來了, 這樣的採樣方式最終訓練出來的 logits 會是對的嗎? 它與未採樣前 (full set) 的 logtis 有何對應關係?
採用 candidate sampling 方式的 softmax loss 在 tensorflow 中已經直接有 op 了, 參考 tf.nn.sampled_softmax_loss. 文檔裡最終推導得到如下的一個式子:
log(P(y|xi,Ci))=log(P(y|xi))−log(Q(y|xi))+K′(xi,Ci)推導過程自行看文檔就可以, 重要的是了解式子的物理意義.
Ci 是對 input xi 採樣出的 subset, 包含了 一個正確的類別標籤 和 其他採樣出的類別 Si. Q(y|xi) 是基於 input xi, label y 被選中成為 Si 的機率. K′ 是跟 y 無關的, 所以對於式子來說是 constant. 注意到式子的變數是 y 代表了是 softmax 的哪一個 output node.
式 (1) 的解釋為: “在 candidate set Ci 下的 logits 結果” 等於 “在 full set 下的 logtis 結果減去 logQ(y|xi)”, K′ 會直接被 logP(y|xi) 吸收, 因為 logits 加上 constant 對於 softmax 來說會分子分母消掉, 所以不影響.
以下我們順便複習一下, 為什麼 logits 可以寫成 “const+logP(y|x)” 這種形式. (包含複習 Entropy, cross-entropy, softmax loss)
Entropy 定義
∑iq(xi)log1q(xi)對於 input xi, 其機率為 q(xi), 若我們使用 log1q(xi) 這麼多 bits 的數量來 encode 它的話, 則上面的 entropy 代表了 encode 所有 input 所需要的平均 bits 數, 而這個數是最小的.
用錯誤的 encoding 方式
我們假設用 log1p(xi) 這麼多 bits 的數量來 encode 的話, 則平均 encode bits 數為:
∑iq(xi)log1p(xi)這個數量一定會比 entropy 來的大, 而大出來的值就是我們使用錯誤的 encoding 造成的代價 (cross-entropoy).
Cross-entropy
如上面所說, 錯誤的 encoding 方式造成的代價如下:
Xent(p,q)≜∑iq(xi)log1p(xi)−∑iq(xi)log1q(xi)=∑iq(xi)logq(xi)p(xi)Sparse softmax loss
最常見的情形為當只有 q(xj)=1 而其他 x≠xj 時 q(x)=0 的話 (q 變成 one-hot), 上面的 corss-entropy 變成:
SparseSoftmaxLoss≜Xent(p,q is one-hot)=−logp(xj)=−logezj∑iezi=−logezj+log∑iezi=−zj+log∑iezi其中 zi 表示 i-th logtis, 參考 tf.nn.sparse_softmax_cross_entropy_with_logits
Logits 的解釋
j-th logtis zj 可解釋為 “const + class j 的 log probability”.
zj=cosnt+logp(j)為什麼呢? 這是因為 logtis 經過 softmax 後會變成機率, 我們假設經過 softmax 後 node j 的機率為 p′(j), 計算一下這個值:
p′(j)=ezj∑iezi=elogp(j)econsteconst∑ielogp(i)=p(j)∑ip(i)=p(j)這時候我們再回去對照開始的式 (1), 就能清楚的解釋 candidate sampling 的 logtis 和 full set 的 logits 之間的關係了.
Sampled softmax loss
由式 (1) 我們已經知道 candidate sampling 的 logtis 和 full set 的 logits 之間的關係. 因此在訓練的時候, 正常 forward propagation 到 logits 時, 這時候的 logits 是 full set 的. 但由於我們計算 softmax 只會在 candidate set 上. 因此要把 full set logits 減去 logQ(y|xi), 減完後才會是正確的 candiadtes logits.
對於 inference 部分, 則完全照舊, 因為原本 forward propagation 的結果就是 full set logits 了. 這也是 tf 官網範例這麼寫的原因:
|
|