Learning Zero Point and Scale in Quantization Parameters


在上一篇 搞懂 Quantization Aware Training 中的 Fake Quantization 我們討論了 fake quantization 以及 QAT
提到了 observer 負責計算 zero point and scale (z,s), 一般來說只需要透過統計觀測值的 min/max 範圍就能給定, 所以也不需要參與 backward 計算

直觀上我們希望找到的 zero/scale 使得 quantization error 盡量小, 但其實如果能對任務的 loss 優化, 應該才是最佳的
這就必須讓 (z,s) 參與到 backward 的計算, 這種可以計算 gradient 並更新的做法稱為 learnable quantization parameters

本文主要參考這兩篇論文:
 1. LSQ: Learned Step Size Quantization
 2. LSQ+: Improving low-bit quantization through learnable offsets and better initialization

LSQ 只討論 updating scale, 而 LSQ+ 擴展到 zero point 也能學習, 本文只推導關鍵的 gradients 不說明論文裡的實驗結果

很快定義一下 notations:
 - v: full precision input value
 - s: quantizer step size (scale)
 - z: zero point (offset)
 - QP,QN: the number of positive and negative quantization levels
  e.g.: for b bits, unsigned QN=0,QP=2b1, for signed QN=2b1,QP=2b11
 - x: round x to nearest integer
v quantize 到 ˉv (1), 再將 ˉv dequantize 回 ˆv (2), 而 vˆv 就是 precision loss
ˉv=clip(v/s+z,QN,QP)ˆv=(ˉvz)×s

學習 Scale


因為在 forward 的時候是 ˆv 去參與 Loss L 的計算 (不是 v), 所以計算 s 的 gradient 時 Loss L 必須對 ˆv 去微分, 因此
Ls=Lˆvˆvs

其中 L/ˆv 是做 backprop 時會傳進來的, 所以需要計算 ˆv/s
ˆvs=(ˉvz)ss=sˉvs+ˉvz=s{vs2if QN<v/s+z<QP0otherwise+ˉvz
橘色的地方 ˉv/s 必須使用 STE (Straight Through Estimator) (參考上一篇筆記)
ˉv 用這樣表達:
ˉv={v/s+zif QN<v/s+z<QPQNif v/s+zQNQPif QPv/s+z
所以代回去 (5) 得到我們要的 scale 的 gradients:
ˆvs={v/s+v/sif QN<v/s+z<QPQNzif v/s+zQNQPzif v/s+zQP
在 LSQ 這篇的作者把 gradients ˆv/s 畫出來, 可以看到在 quantization 的 transition 處, LSQ 能體現出 gradient 變動很大 (另外兩個方法沒有)

Scale 的 Gradient 要做調整


LSQ 作者實驗認為 weights 和 scale 的 gradients 大小, 再除以各自的參數數量後, 如果在比例上一樣的話效果比較好:

R=|sLs|/wLw1 要讓更新的相對大小是接近的, 因此會把 gradients 乘上如下的 scale 值: g=1/NQP, 其中 N 是那一層的 (pytorch) tensor 總數量 .numel

Weight tensor W 就是 W.numel, 而如果要處理 scale s 的話, 假設處理的是 activations X, 那就是 X.numel

實作

這個 gradient scale 的技巧很好, 可以用在任何不想改變 output 大小, 而又希望改變 gradient 大小的場合使用

學習 Zero Point


推導 zero point 的 gradient (式子打不出來很怪, 只能用圖片):

對照 PyTorch 實作


Pytorch 實作: _fake_quantize_learnable_per_tensor_affine_backward 裡面註解寫著如下的敘述:
The gradients for scale and zero point are calculated as below:
Let Xfq be the fake quantized version of X.
Let Xq be the quantized version of X (clamped at qmin and qmax).
Let Δ and z be the scale and the zero point.

式子打不出來很怪, 只能用圖片:

可以發現與 gradient of scale (7) 和 gradient of zero point (12) 能對照起來

一些訓練說明


有關 initialization 可以從 post quantization 開始, 不一定要照論文的方式
其中第一和最後一層都使用 8-bits (我覺得甚至用 32-bits 都可以), 這兩層用高精度能使得效果顯著提升, 已經是個標準做法了
另一個標準做法是 intial 都從 full precision 開始

Reference


  1. LSQ: Learned Step Size Quantization
  2. LSQ+: Improving low-bit quantization through learnable offsets and better initialization
  3. 重训练量化·可微量化参数: 有 zero point 的微分推導
  4. Pytorch 實作: _fake_quantize_learnable_per_tensor_affine_backward
  5. 量化训练之可微量化参数—LSQ
  6. 別人的實作: lsq-net: https://github.com/zhutmost/lsq-net/blob/master/quan/quantizer/lsq.py