看完本文會知道什麼是 fake quantization 以及跟 QAT (Quantization Aware Training) 的關聯
同時了解 pytorch 的 torch.ao.quantization.fake_quantize.FakeQuantize
這個 class 做了什麼
Fake quantization 是什麼?
我們知道給定 zero (z) and scale (s) 情況下, float 數值 r 和 integer 數值 q 的關係如下:
r=s(q−z)q=round_to_int(r/s)+zint8
Fake quantization 主要概念就是用 256 個 float 點 (e.g. 用
int8
) 來表示所有 float values, 因此一個 float value 就使用256點中最近的一點 float 來替換則原來的 floating training 流程都不用變, 同時也能模擬因為 quantization 造成的精度損失, 這種訓練方式稱做 Quantization Aware Training (QAT) (See Quantization 的那些事)
令一個 tensor x
如下, 數值參考 pytorch 官方範例 (link):
|
|
同時令 zero and scale 和 integer 為 int8
|
|
則我們可以使用 torch.fake_quantize_per_tensor_affine
(link) 來找出哪一個256點的 float 最接近原來的 x
的 float 值
|
|
其實我們也可以用式 (2) 先算出 quantized 的值, 然後再用 (1) 回算最靠近的 float, 這樣計算應該要跟上面使用 torch.fake_quantize_per_tensor_affine
的結果一樣:
|
|
Fake quantization 必須要能微分
既然要做 QAT, 也就是說在 back propagation 時, fake quantization 這個 function 也要能微分
我們看一下 fake quantization function 長相:
基本上就是一個 step function, 除了在有限的不連續點外, 其餘全部都是平的, 所以 gradient 都是 0.
這導致沒法做 back propagation. 為了讓 gradient 流回去, 我們使用 identity mapping (假裝沒有 fake quantization) 的 gradient:
那讀者可能會問, 這樣 gradient 不就跟沒有 fake quantization 一樣了嗎? 如何模擬 quantization 造成的精度損失?
我們來看看加上 loss 後的情形, 就可以解答這個問題
隨便假設一個 loss function 如下(可以是非常複雜的函數, 例如裡面含有NN):
loss=(x−0.1)2
原來的 training flow 是上圖中的上面子圖, loss function 使用 x 代入計算, 而使用 fake quantization training 的話必須代入 fq_x. 這樣就能在計算 loss 的時候模擬精度損失.
我們觀察一下 gradient:
dlossdx=dlossdfq_x⋅dfq_xdx=2(fq_x−0.1)⋅{0or1}
接續上面的 codes 我們來驗算一下 gradient 是不是如同 (4) 這樣
|
|
注意到 x.grad[-1]
的值是 0, 這是因為 x[-1]
已經小於 quant_min
了, 所以 fake quantization 的 gradient, dfq_xdx=0, 其他情況都是 2(fq_x−0.1).
這個做法跟 so called STE (Straight-Through Estimator) 是一樣的意思 [1], 用來訓練 binary NN [6]
一篇易懂的文章 “Intuitive Explanation of Straight-Through Estimators with PyTorch Implementation“
加入 observer
要做 fake quantization 必須給定 zero and scale (z,s), 而這個值又必須從 input (或說 activation) 的值域分布來統計
因此我們通常會安插一個 observer
來做這件事情
pytorch 提供了不同種類的統計方式來計算 (z,s), 例如:
- MinMaxObserver and MovingAverageMinMaxObserver
- PerChannelMinMaxObserver and MovingAveragePerChannelMinMaxObserver
- HistogramObserver
- FixedQParamsObserver
因此一個完整個 fake quantization 包含了 observer 以及做 fake quantization 的 function, FakeQuantize
這個 pytorch class 就是這個功能:
observer
只是用來給 (z,s) 不需要做 back propagation
但其實 scale s 也可以 learnable! 參考 “Learned Step Size Quantization“ (待讀)
因此我們可以看到要 create FakeQuantize
時, 它的 init 有包含給一個 observer
:
|
|
FakeQuantize
這個 class 是 nn.Module
, 只要 forward
裡面的每個 operation 都有定義 backward
(都可微分), 就自動可以做 back propagation
本文最開頭有展示
torch.fake_quantize_per_tensor_affine
可以做backward
, 是可以微分的 op
最後, 在什麼地方安插 FakeQuantize
會根據不同的 module (e.g. CNN, dethwise CNN, LSTM, GRU, … etc.) 而不同, 同時也必須考量如果有 batch normalization, concate operation, add operation 則會有一些 fusion, requantize 狀況要注意
Figure Backup
Reference
- Estimating or Propagating Gradients Through Stochastic Neurons for Conditional Computation: STE paper 2013 Yoshua Bengio
- Intuitive Explanation of Straight-Through Estimators with PyTorch Implementation: STE 介紹, 包含用 Pytorch 實作
torch.ao.quantization.fake_quantize.FakeQuantize
(link)torch.fake_quantize_per_tensor_affine
(link)- Learned Step Size Quantization: scale s 也可以 learnable (待讀)
- 二值网络,围绕STE的那些事儿