TF Notes (6), Candidate Sampling, Sampled Softmax Loss


NN 做分類最後一層通常使用 softmax loss, 但如果類別數量很大會導致計算 softmax 的 cost 太高, 這樣會讓訓練變得很慢. 假如總共的 class 數量是 10000 個, candidate sampling 的想法就是對於一個 input $x$ 採樣出一個 subset (當然需要包含正確的 label), 譬如只用 50 個 classes, 扣掉正確的那個 class, 剩下的 49 個 classes 從 9999 個採樣出來. 然後計算 softmax 只在那 50 個計算. 那麼問題來了, 這樣的採樣方式最終訓練出來的 logits 會是對的嗎? 它與未採樣前 (full set) 的 logtis 有何對應關係?

採用 candidate sampling 方式的 softmax loss 在 tensorflow 中已經直接有 op 了, 參考 tf.nn.sampled_softmax_loss. 文檔裡最終推導得到如下的一個式子:

$$\begin{align} \log(P(y|x_i,C_i))=\log(P(y|x_i))-\log(Q(y|x_i))+K'(x_i,C_i) \end{align}$$

推導過程自行看文檔就可以, 重要的是了解式子的物理意義.
$C_i$ 是對 input $x_i$ 採樣出的 subset, 包含了 一個正確的類別標籤其他採樣出的類別 $S_i$. $Q(y|x_i)$ 是基於 input $x_i$, label $y$ 被選中成為 $S_i$ 的機率. $K’$ 是跟 $y$ 無關的, 所以對於式子來說是 constant. 注意到式子的變數是 $y$ 代表了是 softmax 的哪一個 output node.

式 (1) 的解釋為: “在 candidate set $C_i$ 下的 logits 結果” 等於 “在 full set 下的 logtis 結果減去 $\log Q(y|x_i)$”, $K’$ 會直接被 $\log P(y|x_i)$ 吸收, 因為 logits 加上 constant 對於 softmax 來說會分子分母消掉, 所以不影響.

以下我們順便複習一下, 為什麼 logits 可以寫成 “$\mbox{const}+\log P(y|x)$” 這種形式. (包含複習 Entropy, cross-entropy, softmax loss)


Entropy 定義

$$\begin{align} \sum_i{q(x_i)\log{\frac{1}{q(x_i)}}} \end{align}$$

對於 input $x_i$, 其機率為 $q(x_i)$, 若我們使用 $\log{\frac{1}{q(x_i)}}$ 這麼多 bits 的數量來 encode 它的話, 則上面的 entropy 代表了 encode 所有 input 所需要的平均 bits 數, 而這個數是最小的.


用錯誤的 encoding 方式

我們假設用 $\log{\frac{1}{p(x_i)}}$ 這麼多 bits 的數量來 encode 的話, 則平均 encode bits 數為:

$$\begin{align} \sum_i{q(x_i)\log{\frac{1}{p(x_i)}}} \end{align}$$

這個數量一定會比 entropy 來的大, 而大出來的值就是我們使用錯誤的 encoding 造成的代價 (cross-entropoy).


Cross-entropy

如上面所說, 錯誤的 encoding 方式造成的代價如下:

$$\begin{align} \mbox{Xent}(p,q)\triangleq\sum_i{q(x_i)\log{\frac{1}{p(x_i)}}} - \sum_i{q(x_i)\log{\frac{1}{q(x_i)}}} \\ =\sum_i{q(x_i)\log{\frac{q(x_i)}{p(x_i)}}} \\ \end{align}$$

Sparse softmax loss

最常見的情形為當只有 $q(x_j)=1$ 而其他 $x\neq x_j$ 時 $q(x)=0$ 的話 ($q$ 變成 one-hot), 上面的 corss-entropy 變成:

$$\begin{align} \mbox{SparseSoftmaxLoss}\triangleq\mbox{Xent}(p,q\mbox{ is one-hot})=-\log p(x_j) \\ =-\log\frac{e^{z_j}}{\sum_i{e^{z_i}}}=-\log e^{z_j} + \log\sum_i{e^{z_i}} \\ =-z_j + \log\sum_i{e^{z_i}} \end{align}$$

其中 $z_i$ 表示 i-th logtis, 參考 tf.nn.sparse_softmax_cross_entropy_with_logits


Logits 的解釋

j-th logtis $z_j$ 可解釋為 “const + class $j$ 的 log probability”.

$$\begin{align} z_j = \mbox{cosnt} + \log p(j) \end{align}$$

為什麼呢? 這是因為 logtis 經過 softmax 後會變成機率, 我們假設經過 softmax 後 node $j$ 的機率為 $p’(j)$, 計算一下這個值:

$$\begin{align} p'(j)=\frac{e^{z_j}}{\sum_i e^{z_i}} \\ =\frac{e^{\log p(j)}e^{\mbox{const}}}{e^{\mbox{const}}\sum_i e^{\log p(i)}} \\ =\frac{p(j)}{\sum_i p(i)} \\ =p(j) \end{align}$$

這時候我們再回去對照開始的式 (1), 就能清楚的解釋 candidate sampling 的 logtis 和 full set 的 logits 之間的關係了.


Sampled softmax loss

由式 (1) 我們已經知道 candidate sampling 的 logtis 和 full set 的 logits 之間的關係. 因此在訓練的時候, 正常 forward propagation 到 logits 時, 這時候的 logits 是 full set 的. 但由於我們計算 softmax 只會在 candidate set 上. 因此要把 full set logits 減去 $\log Q(y|x_i)$, 減完後才會是正確的 candiadtes logits.

對於 inference 部分, 則完全照舊, 因為原本 forward propagation 的結果就是 full set logits 了. 這也是 tf 官網範例這麼寫的原因:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
if mode == "train":
loss = tf.nn.sampled_softmax_loss(
weights=weights,
biases=biases,
labels=labels,
inputs=inputs,
...,
partition_strategy="div")
elif mode == "eval":
logits = tf.matmul(inputs, tf.transpose(weights))
logits = tf.nn.bias_add(logits, biases)
labels_one_hot = tf.one_hot(labels, n_classes)
loss = tf.nn.softmax_cross_entropy_with_logits(
labels=labels_one_hot,
logits=logits)

Reference

  1. tf.nn.sampled_softmax_loss
  2. Candidate Sampling
  3. tf.nn.sparse_softmax_cross_entropy_with_logits