LoRAPrune, Pruning Meets Low-Rank Parameter-Efficient Fine-Tuning 筆記


本文是這篇論文 “LoRAPrune: Pruning Meets Low-Rank Parameter-Efficient Fine-Tuning [arxiv]” 的筆記.

一般來說使用 first-order Taylor importance 的 pruning 方法 (下面會介紹此法) 需計算 gradients 來對每個 weight 計算重要性, 然後根據重要性剪枝. 但是現在模型已經愈來愈大, 對所有 weights 都須計算 gradient 的負擔太大.

另一方面, 在 LLM 中對於大模型的 fine tuning 使用 LoRA (PEFT, Parameter Efficient Fine Tuning, 的一種) 來計算 gradients 非常有效率, 原因是對原來的 weights 是 fixed 的, 只 train LoRA 外掛的”少量”參數, 因此只有少量的 gradients 需要計算. 不過我們思考一下, 如果要對已經 prune 的 weights 旁邊外掛 LoRA 的話, LoRA train 完後沒辦法 merge 回去原來的 weights, 因為有可能打亂原本要 prune 的位置. 但是反過來說, 如果先用 LoRA fine tune 完才進行剪枝, 又回到當模型太大而負擔太大沒效率的問題. 況且這樣分兩步驟可能不是很直接, 如果能在 LoRA fine tune 時就能一併考慮某些 weights 會被 prune 的情況下去 fine tune 可能會更好.

如何 pruning 原來的參數又能利用上 LoRA 的效率就是此篇論文的工作.

$$\begin{array}{|c |c |c |} \hline & 能否對原來的參數做剪枝? & 是否很有效率? \\ \hline \text{1st order pruning} & \text{Yes} & \text{No} \\ \hline \text{LoRA} & \text{No} & \text{Yes} \\ \hline \text{LoRAPrune} & \text{Yes} & \text{Yes} \\ \hline \end{array}$$

以下會先介紹 first-order Taylor importance 的 pruning 方法, 再來介紹 LoRA, 最後說明如何取兩者之優點得出此篇的方法: LoRAPrune

First-order Taylor Importance Pruning


對 weight $w_{ij}$ 的 importance score 估計, 是以該 weight 被 prune 掉的話 ($w_{ij}=0$), 對 loss 有多少影響來當依據, 所以:

$(W_0)_{ij}$$w_{ij}$ 表示

$$\begin{align} \mathcal{I}_{ij}=(\mathcal{L}(x,y,W_0)-\mathcal{L}(x,y,W_0|w_{ij}=0))^2 \end{align}$$

複習一下 Taylor expansion
$$f(x)=f(a)+{f'(a)\over 1!}(x-a)+{f''(a)\over 2!}(x-a)^2+{f'''(a)\over 3!}(x-a)^3+...$$ 所以
$$\mathcal{L}(W-\delta W) = \mathcal{L}(W) + \nabla_W \mathcal{L}^T\cdot(-\delta W) + {1\over2}(-\delta W)^T\cdot(\nabla_W^2 \mathcal{L})\cdot(-\delta W) +... \\ \Longrightarrow \mathcal{L}(W)-\mathcal{L}(W-\delta W)= \nabla_W \mathcal{L}^T\cdot(\delta W) - {1\over2}\delta W^T\cdot(\nabla_W^2 \mathcal{L})\cdot\delta W + ...$$

假設二次項之後影響都比一次項小很多, 因此我們可以把參數 $w_{ij}$ 的 importance score 設定成一次項的 power:
(這時的 $\delta W=w_{ij}$)
$$\begin{align} \mathcal{\hat I}_{ij}=\left( {\partial\mathcal{L}\over\partial w_{ij}}w_{ij} \right)^2 \end{align}$$

我們就根據 $\mathcal{\hat I}_{ij}$ 來逐步剪枝不要的參數

LoRA


LoRA (Low-Rank Adaptation) 公式為:
$$\begin{align} z=xW_0+xBA \end{align}$$ 其中 $W_0\in\mathbf{R}^{d\times k}$ 是原來 model 的參數, $A\in\mathbf{R}^{r\times k}$ and $B\in\mathbf{R}^{d\times r}$ 是 LoRA 的兩個 learnable low rank (rank $r$) 參數.
會將 $W_0$ fixed 住, 只學 $A$ and $B$, 且由於 rank $r$ 通常都不大, 因此很有效率. 注意到為了保證 initial 的時候 performance (output) 跟原來一樣, 會將 $B$ initial 成 $0$ matrix ($A$ random Guassian 即可)
學完之後, 可將 $A,B$ 的參數 merge 回 $W_0$, 所以 inference 不會增加額外計算量
$$\begin{align} W=W_0+BA \end{align}$$

LoRAPrune


如果要將 $w_{ij}$ prune 掉的話, 相當於設定 $(BA)_{ij}=-w_{ij}$, 所以 importance score (1) 改寫如下:
$$\begin{align} \mathcal{I}_{ij}=(\mathcal{L}(x,y,W_0)-\mathcal{L}(x,y,W_0|(BA)_{ij}=-w_{ij}))^2 \end{align}$$ 如同上面一樣 first order Taylor approximation 為:

$$\begin{align} \mathcal{\hat I}_{ij}=\left( {\partial\mathcal{L}\over\partial (BA)_{ij}}((BA)_{ij}+w_{ij}) \right)^2 \end{align}$$ 注意到 $W_0$ 是 fixed 住, 而 $A,B$ 才是 learnable parameters, 所以是對 $(BA)_{ij}$ 偏微分
其中由於 SGD update 公式的關係, (6) 的偏微分那項可這麼看待:
$$\begin{align} {\partial\mathcal{L}\over\partial(BA)_{ij}}\propto (BA)_{ij}|_t - (BA)_{ij}|_{t+1} \end{align}$$ $t$ 為當下的 weights, $t+1$ 是要 update 的 SGD iteration, 繼續拆解如下:
$$\begin{align} {\partial\mathcal{L}\over\partial(BA)_{ij}}\propto\left[ B_{i:}A_{:j}- \left(B_{i:}-\frac{\partial\mathcal{L}}{\partial B_{i:}}\right) \left(A_{:j}-\frac{\partial\mathcal{L}}{\partial A_{:j}}\right) \right] \\ =\left[ \frac{\partial\mathcal{L}}{\partial B_{i:}}A_{:j} + B_{i:}\frac{\partial\mathcal{L}}{\partial A_{:j}} - \frac{\partial\mathcal{L}}{\partial B_{i:}}\frac{\partial\mathcal{L}}{\partial A_{:j}} \right] \end{align}$$ 將 (9) 代回 (6) 得到:
$$\begin{align} \mathcal{\hat I}_{ij}=\left( (\nabla B \cdot A + B\cdot\nabla A - \nabla B\cdot\nabla A)\odot(BA+W_0) \right)^2 \end{align}$$

其中 $\odot$ 表示 element-wised 相乘, 到這裡我們發現只使用 $A,B$ 的 gradients, 因此保有了 LoRA 效率的好處.

💡 總結一下精神: 原來所有 weights 的 first-order Taylor importance scores $\mathcal{I}_{ij}$ (式 5) 在 fine tune LoRA 時使用它的”少量”參數的 gradients 來逼近 $\mathcal{\hat I}_{ij}$ (式 10), 這樣計算 importance score 沒效率的情形就能被改善.

Progressive LoRAPrune


在計算 forward and backward 的時候是使用 masking 的方式計算:
$$\begin{align} z=(xW_0+xBA)\odot M \end{align}$$ 其中 $M$ 是 binary mask, 是根據 importance score $\mathcal{\bar I}$ 計算得到, 而 $\mathcal{\bar I}$ 只是個 smoothed 過後的 $\mathcal{\hat I}$ (10) 而已
$$\begin{align} \mathcal{\bar I}|_t=\lambda\mathcal{\bar I}|_{t-1}+(1-\lambda)\mathcal{\hat I}|_t \end{align}$$

注意到由於直接乘 mask $M$, 沒有特別使用 STE 來讓 mask = 0 的地方的 gradient 流通, 因此被 mask 的 $i,j$ 會沒有 gradients, 但其實 $B_{i:}$$A_{:j}$ 還是有機會被其他位置的 gradients 更新到, 例如 $M_{ik}\neq0$$B_{i:}$ 還是會被 update, $M_{lj}\neq0$$A_{:j}$ 也會被 update, 綜合起來 $(BA)_{ij}$ 也被改變了. 也因此就算 $M_{ij}=0$, $w_{ij}$ 還是有敗部復活的機會.

所以 progressive LoRAPrune 流程如下

論文後面有些實驗很有意思, 例如使用 $\frac{\partial\mathcal{L}}{\partial w_{ij}}$ 來替換 (6) 中的 $\frac{\partial\mathcal{L}}{\partial (BA)_{ij}}$. 再請有興趣的讀者自行閱讀論文.

References


  1. LoRAPrune: Pruning Meets Low-Rank Parameter-Efficient Fine-Tuning [arxiv]
  2. LoRA: Low-Rank Adaptation of Large Language Models [arxiv]