Exp of Adversarial Domain Adaptation


Domain Adaptation 是希望在 source domain 有 label 但是 target domain 無 label 的情況下, 能針對 target domain (或同時也能對 source domain) 進行分類任務. “Adversarial” 的意思是利用 GAN 的 “對抗” 想法: Label predictor 雖然只能保證 source domain 的分類. 但由於我們把 feature 用 GAN 消除了 domain 之間的差異, 因此我們才能期望這時候的 source domain classifier 也能作用在 target domain.

這篇文章 張文彥, 開頭的圖傳達的意思很精確, 請點進去參考.

接著嘗試複現了一次 Domain-Adversarial Training of Neural Networks 的 mnist(source) to mnist_m(target) 的實驗.

上一篇說明 GAN 的 framework:

$$\begin{align} Div\left(P_d\|P_G\right) = \max_D\left[ E_{x\sim P_d} D(x) - E_{x\sim P_G}f^*(D(x)) \right] \\ G^*=\arg\min_G{Div\left(P_d\|P_G\right)} + reg(G) \\ \end{align}$$

對於 Adversarial Domain Adaptation 來說只要在正常 GAN 的 training 流程中, update $G$ 時多加一個 regularization term $reg(G)$ 就可以了. 而 $reg(G)$ 就是 Label Predictor 的 loss, 作用就是 train $G$ 時除了要欺騙 $D$, 同時要能降低 prediction error.


實驗

source domain 為標準的 mnist, 而 target domain 是 modified mnist, 如何產生可以參考Daipuwei/DANN-MNIST.

下圖是 mnist_m 的一些範例:

我們先來看一下分佈, 藍色的點是 mnist, 紅色是 mnist_m, 用 tSNE 跑出來的結果明顯看到兩個 domain 分佈不同:

我們之前說過, 不用管 GRL (Gradient Reversal Layer), 就一般的 GAN 架構, 加上 regularization term 就可以. 聽起來很容易, 我就隨手自己用了幾個 CNN 在 generator, 幾層 fully connected layers 給 classifier 和 discriminator 就做了起來. 發現怎麼弄都訓練不起來! 產生下面兩種情形:

  1. GAN too weak:

    重新調整了一下 $reg(G)$ 的比重後….

  2. GAN too strong:

    兩個 domain 的 features 幾乎完全 overlap, 然後 classifier 幾乎無作用 (也看不出有10個分群). 話說, 這圖很像腦的紋路? 貪食蛇? 迷宮? 肚子裡的蛔蟲?

後來在嘗試調了幾個參數後仍然訓練不起來. 這讓我感到很挫折. 實在受不了後, 參考了網路上的做法改成以下幾點:

  1. WGAN 改成用 MMGAN
  2. RMSProp(1e-4) 改成 Adam(1e-3)
  3. 使用網路上一個更簡單的架構 github
  4. 改成用 MMGAN 後, 去掉 BN layer 就能訓練起來

然後就可以訓練起來了(翻桌xN), 訓練後的結果如下:

可以看到在 mnist 辨識率 ~99% 的情形下, mnist_m 能夠有 83.6% 的辨識率 (沒做 adaptation 只有約50%)

Feature 的分布如下圖 (藍色的點是 mnist, 紅色是 mnist_m):

雖然還有一些 feature 沒有完全 match 到, 但已經很重疊了. 同時我們也能明顯到到 10 群的分類.


結論

雖然理論上的理解很容易, 但實作起來卻發現很難調整. GAN 就是那麼難搞阿….


Reference

  1. GAN framework
  2. Domain-Adversarial Training of Neural Networks
  3. 參考產生 mnist_m 的 codes Daipuwei/DANN-MNIST
  4. Domain-Adversarial Training of Neural Networks with TF2.0: lancerane/Adversarial-domain-adaptation
  5. 張文彥 Domain-adaptation-on-segmentation
  6. 自己實驗的 jupyter notebook