Domain Adaptation 是希望在 source domain 有 label 但是 target domain 無 label 的情況下, 能針對 target domain (或同時也能對 source domain) 進行分類任務. “Adversarial” 的意思是利用 GAN 的 “對抗” 想法: Label predictor 雖然只能保證 source domain 的分類. 但由於我們把 feature 用 GAN 消除了 domain 之間的差異, 因此我們才能期望這時候的 source domain classifier 也能作用在 target domain.
這篇文章 張文彥, 開頭的圖傳達的意思很精確, 請點進去參考.
接著嘗試複現了一次 Domain-Adversarial Training of Neural Networks 的 mnist(source) to mnist_m(target) 的實驗.
上一篇說明 GAN 的 framework:
$$\begin{align} Div\left(P_d\|P_G\right) = \max_D\left[ E_{x\sim P_d} D(x) - E_{x\sim P_G}f^*(D(x)) \right] \\ G^*=\arg\min_G{Div\left(P_d\|P_G\right)} + reg(G) \\ \end{align}$$對於 Adversarial Domain Adaptation 來說只要在正常 GAN 的 training 流程中, update $G$ 時多加一個 regularization term $reg(G)$ 就可以了. 而 $reg(G)$ 就是 Label Predictor 的 loss, 作用就是 train $G$ 時除了要欺騙 $D$, 同時要能降低 prediction error.
實驗
source domain 為標準的 mnist, 而 target domain 是 modified mnist, 如何產生可以參考Daipuwei/DANN-MNIST.
下圖是 mnist_m 的一些範例:
我們先來看一下分佈, 藍色的點是 mnist, 紅色是 mnist_m, 用 tSNE 跑出來的結果明顯看到兩個 domain 分佈不同:
我們之前說過, 不用管 GRL (Gradient Reversal Layer), 就一般的 GAN 架構, 加上 regularization term 就可以. 聽起來很容易, 我就隨手自己用了幾個 CNN 在 generator, 幾層 fully connected layers 給 classifier 和 discriminator 就做了起來. 發現怎麼弄都訓練不起來! 產生下面兩種情形:
GAN too weak:
重新調整了一下 $reg(G)$ 的比重後….GAN too strong:
兩個 domain 的 features 幾乎完全 overlap, 然後 classifier 幾乎無作用 (也看不出有10個分群). 話說, 這圖很像腦的紋路? 貪食蛇? 迷宮? 肚子裡的蛔蟲?
後來在嘗試調了幾個參數後仍然訓練不起來. 這讓我感到很挫折. 實在受不了後, 參考了網路上的做法改成以下幾點:
- WGAN 改成用 MMGAN
RMSProp(1e-4)
改成Adam(1e-3)
- 使用網路上一個更簡單的架構 github
- 改成用 MMGAN 後, 去掉 BN layer 就能訓練起來
然後就可以訓練起來了(翻桌xN), 訓練後的結果如下:
可以看到在 mnist 辨識率 ~99% 的情形下, mnist_m 能夠有 83.6% 的辨識率 (沒做 adaptation 只有約50%)
Feature 的分布如下圖 (藍色的點是 mnist, 紅色是 mnist_m):
雖然還有一些 feature 沒有完全 match 到, 但已經很重疊了. 同時我們也能明顯到到 10 群的分類.
結論
雖然理論上的理解很容易, 但實作起來卻發現很難調整. GAN 就是那麼難搞阿….
Reference
- GAN framework
- Domain-Adversarial Training of Neural Networks
- 參考產生 mnist_m 的 codes Daipuwei/DANN-MNIST
- Domain-Adversarial Training of Neural Networks with TF2.0: lancerane/Adversarial-domain-adaptation
- 張文彥 Domain-adaptation-on-segmentation
- 自己實驗的 jupyter notebook