在上一篇 搞懂 Quantization Aware Training 中的 Fake Quantization 我們討論了 fake quantization 以及 QAT
提到了 observer
負責計算 zero point and scale $(z,s)$, 一般來說只需要透過統計觀測值的 min/max 範圍就能給定, 所以也不需要參與 backward 計算
直觀上我們希望找到的 zero/scale 使得 quantization error 盡量小, 但其實如果能對任務的 loss 優化, 應該才是最佳的
這就必須讓 $(z,s)$ 參與到 backward 的計算, 這種可以計算 gradient 並更新的做法稱為 learnable quantization parameters
本文主要參考這兩篇論文:
1. LSQ: Learned Step Size Quantization
2. LSQ+: Improving low-bit quantization through learnable offsets and better initialization
LSQ 只討論 updating scale, 而 LSQ+ 擴展到 zero point 也能學習, 本文只推導關鍵的 gradients 不說明論文裡的實驗結果
很快定義一下 notations:
- $v$: full precision input value
- $s$: quantizer step size (scale)
- $z$: zero point (offset)
- $Q_P,Q_N$: the number of positive and negative quantization levels
e.g.: for $b$ bits, unsigned $Q_N=0,Q_P=2^b-1$, for signed $Q_N=2^{b-1},Q_P=2^{b-1}-1$
- $\lfloor x \rceil$: round $x$ to nearest integer
將 $v$ quantize 到 $\bar{v}$ (1), 再將 $\bar{v}$ dequantize 回 $\hat{v}$ (2), 而 $v-\hat{v}$ 就是 precision loss
$$\begin{align}
\bar{v}={clip(\lfloor v/s \rceil+z,-Q_N,Q_P)} \\
\hat{v}=(\bar{v}-z)\times s\\
\end{align}$$