Sharp V.S. Flat Local Minimum 的泛化能力
先簡單介紹這篇文章:
On large-batch training for deep learning: Generalization gap and sharp minima
考慮下圖兩個 minimum, 對於 training loss 來說其 losses 一樣.
從圖可以容易理解到, 如果找到太 sharp 的點, 由於 test and train 的 mismatch, 會導致測試的時候 data 一點偏移就會對 model output 影響很大.
論文用實驗的方式, 去評量一個 local minimum 的 sharpness 程度, 簡單說利用 random perturb 到附近其他點, 然後看看該點 loss 變化的程度如何, 變化愈大, 代表該 local minimum 可能愈 sharp.
然後找兩個 local minimums, 一個估出來比較 sharp 另一個比較 flat. 接著對這兩點連成的線, 線上的參數值對應的 loss 劃出圖來, 長相如下:
這也是目前一個普遍的認知: flat 的 local minimum 泛化能力較好.
所以可以想像, step size (learning rate) 如果愈大, 愈有可能跳出 sharp minimum.
而 batch size 愈小, 表示 gradient 因為 mini-batch 造成的 noise 愈大, 相當於愈有可能”亂跑”跑出 sharp minimum.
但這篇文章僅止於實驗性質上的驗證. Step size and batch size 對於泛化能力, 或是說對於找到比較 flat optimum 的機率會不會比較高? 兩者有什麼關聯呢?
DeepMind 的近期 (2021) 兩篇文章給出了很漂亮的理論分析.