NN 做分類最後一層通常使用 softmax loss, 但如果類別數量很大會導致計算 softmax 的 cost 太高, 這樣會讓訓練變得很慢. 假如總共的 class 數量是 10000 個, candidate sampling 的想法就是對於一個 input $x$ 採樣出一個 subset (當然需要包含正確的 label), 譬如只用 50 個 classes, 扣掉正確的那個 class, 剩下的 49 個 classes 從 9999 個採樣出來. 然後計算 softmax 只在那 50 個計算. 那麼問題來了, 這樣的採樣方式最終訓練出來的 logits 會是對的嗎? 它與未採樣前 (full set) 的 logtis 有何對應關係?
採用 candidate sampling 方式的 softmax loss 在 tensorflow 中已經直接有 op 了, 參考 tf.nn.sampled_softmax_loss. 文檔裡最終推導得到如下的一個式子:
$$\begin{align} \log(P(y|x_i,C_i))=\log(P(y|x_i))-\log(Q(y|x_i))+K'(x_i,C_i) \end{align}$$推導過程自行看文檔就可以, 重要的是了解式子的物理意義.
$C_i$ 是對 input $x_i$ 採樣出的 subset, 包含了 一個正確的類別標籤 和 其他採樣出的類別 $S_i$. $Q(y|x_i)$ 是基於 input $x_i$, label $y$ 被選中成為 $S_i$ 的機率. $K’$ 是跟 $y$ 無關的, 所以對於式子來說是 constant. 注意到式子的變數是 $y$ 代表了是 softmax 的哪一個 output node.
式 (1) 的解釋為: “在 candidate set $C_i$ 下的 logits 結果” 等於 “在 full set 下的 logtis 結果減去 $\log Q(y|x_i)$”, $K’$ 會直接被 $\log P(y|x_i)$ 吸收, 因為 logits 加上 constant 對於 softmax 來說會分子分母消掉, 所以不影響.
以下我們順便複習一下, 為什麼 logits 可以寫成 “$\mbox{const}+\log P(y|x)$” 這種形式. (包含複習 Entropy, cross-entropy, softmax loss)
Entropy 定義
$$\begin{align} \sum_i{q(x_i)\log{\frac{1}{q(x_i)}}} \end{align}$$對於 input $x_i$, 其機率為 $q(x_i)$, 若我們使用 $\log{\frac{1}{q(x_i)}}$ 這麼多 bits 的數量來 encode 它的話, 則上面的 entropy 代表了 encode 所有 input 所需要的平均 bits 數, 而這個數是最小的.
用錯誤的 encoding 方式
我們假設用 $\log{\frac{1}{p(x_i)}}$ 這麼多 bits 的數量來 encode 的話, 則平均 encode bits 數為:
$$\begin{align} \sum_i{q(x_i)\log{\frac{1}{p(x_i)}}} \end{align}$$這個數量一定會比 entropy 來的大, 而大出來的值就是我們使用錯誤的 encoding 造成的代價 (cross-entropoy).
Cross-entropy
如上面所說, 錯誤的 encoding 方式造成的代價如下:
$$\begin{align} \mbox{Xent}(p,q)\triangleq\sum_i{q(x_i)\log{\frac{1}{p(x_i)}}} - \sum_i{q(x_i)\log{\frac{1}{q(x_i)}}} \\ =\sum_i{q(x_i)\log{\frac{q(x_i)}{p(x_i)}}} \\ \end{align}$$Sparse softmax loss
最常見的情形為當只有 $q(x_j)=1$ 而其他 $x\neq x_j$ 時 $q(x)=0$ 的話 ($q$ 變成 one-hot), 上面的 corss-entropy 變成:
$$\begin{align} \mbox{SparseSoftmaxLoss}\triangleq\mbox{Xent}(p,q\mbox{ is one-hot})=-\log p(x_j) \\ =-\log\frac{e^{z_j}}{\sum_i{e^{z_i}}}=-\log e^{z_j} + \log\sum_i{e^{z_i}} \\ =-z_j + \log\sum_i{e^{z_i}} \end{align}$$其中 $z_i$ 表示 i-th logtis, 參考 tf.nn.sparse_softmax_cross_entropy_with_logits
Logits 的解釋
j-th logtis $z_j$ 可解釋為 “const + class $j$ 的 log probability”.
$$\begin{align} z_j = \mbox{cosnt} + \log p(j) \end{align}$$為什麼呢? 這是因為 logtis 經過 softmax 後會變成機率, 我們假設經過 softmax 後 node $j$ 的機率為 $p’(j)$, 計算一下這個值:
$$\begin{align} p'(j)=\frac{e^{z_j}}{\sum_i e^{z_i}} \\ =\frac{e^{\log p(j)}e^{\mbox{const}}}{e^{\mbox{const}}\sum_i e^{\log p(i)}} \\ =\frac{p(j)}{\sum_i p(i)} \\ =p(j) \end{align}$$這時候我們再回去對照開始的式 (1), 就能清楚的解釋 candidate sampling 的 logtis 和 full set 的 logits 之間的關係了.
Sampled softmax loss
由式 (1) 我們已經知道 candidate sampling 的 logtis 和 full set 的 logits 之間的關係. 因此在訓練的時候, 正常 forward propagation 到 logits 時, 這時候的 logits 是 full set 的. 但由於我們計算 softmax 只會在 candidate set 上. 因此要把 full set logits 減去 $\log Q(y|x_i)$, 減完後才會是正確的 candiadtes logits.
對於 inference 部分, 則完全照舊, 因為原本 forward propagation 的結果就是 full set logits 了. 這也是 tf 官網範例這麼寫的原因:
|
|