這是 Transformer inference 的加速, 有人猜測 GPT-4 也使用這個方法: https://archive.ph/2RQ8X
Speculative decoding 做到了不影響準確率情況下直接加速 (不改 model 架構, 不 fine tune, 不做 PTQ 等)
這麼神奇的操作就是利用了一個小模型來先跑一些 tokens, 再由原來的大模型評估或修正.
論文顯示 LLM 效果無損直接可提速 2~3 倍, 讓我們看下去
Motivation
使用 SongHan 教授的課程 slides. 利用 small model 先提出一些 draft tokens, 然後用 large model 來驗證. 如果大部分都接受, 直覺上可以省去很多 large model 的呼叫次數, 因此加速. 方法十分簡單, 不過其實魔鬼藏在細節裡, 跟原本只使用 large model 的方法比較有幾個問題要回答:
A. 速度的分析: 加速到什麼程度? 跟小模型的速度和準確度有關聯嗎? (想像如果 draft 一直被拒絕, 則小模型都是多跑的)
B. 運算量的分析: Operation 數 (計算量) 也會減少嗎? 還是會增加?
C. Memory bandwidth 的分析: 會減少還是增加?
D. Performance 能維持住嗎 (PPL, WER, BLEU, … 端看 model task 是什麼): 還是會有 degrade?
Google 這篇論文很精彩的理論分析了以上所有問題, 並有實務驗證
先破題, performance (PPL, WER, BLEU, …) 可以保證維持住! 我們等到本篇筆記最後在討論, 以下會先討論算法流程、加速和運算量的分析.
Speculative Decoding 算法流程
使用論文的術語, 例如上面說的 small model 改稱 approximation model, large model 改稱 target model, draft 用 proposal tokens.
Approx. model, Mq, 用 auto-regressive 方式產生 γ 個 proposal tokens {x1,...,xγ} 和機率分布 {q1(x),...,qγ(x)}, 接著把 proposal token 結合上次一的 prefix tokens (但這裡我們為了簡化先忽略 prefix) 給 target model, Mp, 做一次 non-autoregressive forward (parallel) 跑出機率分布 {p1(x),...,pγ(x),pγ+1(x)}.
比較 p(x),q(x) 來決定是否接受 proposal tokens, 如果 p(x)≥q(x) 則採用 Mq 的 proposal token, 否則有 p(x)/q(x) 機率仍會接受 proposal token, 有 1−p(x)/q(x) 的機率要根據修改的機率分布 p′(x)=norm(max(0,p(x)−q(x))) 重新採樣 token.
另外如果所有 γ 個 proposal tokens 都被接受了, 則直接從 target model 的 pγ+1(x) 採樣token.
以上為一個 step or run, 重複直到句子產生結束.
參考下圖:
- 注意到 {p1(x),...,pγ(x),pγ+1(x)} 是一次 forward 就跑出來的, 相比 auto-regressive 的方式要跑 γ 次 forward (load γ 次 model 參數), 現在只需要 load 一次參數(and kv-cache)因此可以節省 memory bandwidth. 但注意到這兩種方式的總計算量是不變的.
- 一般來說 Mp 的輸入會結合上一次 decode 的 tokens (稱 prefix) 加上 Mq 的 proposal tokens 當輸入, 但是這些 prefix 由於上一次 decode 時 forward 過, 在使用 kv-cache 的技巧下可以省略計算.
速度和運算量的初步分析
先定義 E(#generated tokens) 表示 speculative decoding 平均一個 run 可以產生多少”有效” tokens (因為不是所有 proposal tokens 都會被接受)
推論速度 (Walltime) 變化?
每一個 run 需要的時間為 Tcγ+T, 其中 T 是跑一次 target model 所花的時間, c (cost coefficient) 是 approx. model 跟 target model 的時間比 (愈小表示 approx. model 跑愈快). 所以:
- speculative decoding 花了 Tcγ+T 的時間產生 E(#generated tokens) 個 tokens
- 只用 target model 花了 T 的時間產生 1 個 token
因此只要知道 E(#generated tokens) 我們可推得使用 speculative decoding 的速度提升 (walltime improvement):
Walltime Improvement=Speculative decoding (tokens per time)Mp decoding (tokens per time)=E(#generated tokens)/(Tcγ+T)1/T=E(#generated tokens)(cγ+1)
運算量的變化?
定義 ˆT 是 target model ”per token” 的運算量, 而 ˆc 是 approx. model 跟 target model 的運算量比. 每一次的 run, approx. model 會 auto-regressive γ 次, 所以是 ˆTˆcγ, 而 target model 會對 γ 個 proposal tokens parallel 去跑 1 次, 注意到雖然是 parallel, 但總運算次數是正比於 proposal token 數量的 (只是並行跑), 所以花的運算量為 ˆT(γ+1). 所以:
- speculative decoding 花了 ˆTˆcγ+ˆT(γ+1) 運算量產生 E(#generated tokens) 個 tokens
- 只用 target model 花了 ˆT 的運算量產生 1 個 token
同樣只要知道 E(#generated tokens) 我們可推得運算量的變化.
PS: 注意到 prefix tokens 不會花到運算量因為 kv-cache 技巧, 所以考慮的時候可以忽略.
#OPs Increasing Ratio=Speculative decoding (\#OPs per token)Mp decoding (\#OPs per token)=(ˆTˆcγ+ˆT(γ+1))/E(#generated tokens)ˆT/1=ˆcγ+γ+1E(#generated tokens)
平均生成的 Tokens 數
Proposal Tokens 被接受的機率 βt,α
綜上所述, 需要先計算 E(#generated tokens), 等同於要計算 token 的 accept 機率我們才能得知速度以及運算量的變化.
將 proposal token xt∼q(x|x1,...,xt−1)=:qt(x) 被 speculative decoding 接受的機率定義為 βt.
數學上可以麼表達 (為了清楚, 在沒有混淆情況下省略下標 t):
β=Ex∼q(x){1q(x)≤p(x)p(x)q(x)q(x)>p(x)=Ex∼q(x)min(1,p(x)q(x))=∑xmin(p(x),q(x))
所以可以簡化為定義
α:=Et[βt]=∑t∑xmin(pt(x),qt(x)) 有趣的是, 以T5系列的 models 來看, Mq 選擇 bi-gram 這種非常簡單的 LM α 還有 0.2, 代表 bi-gram model 的 proposal tokens 平均5個有1個會被接受.
如果 approx. model 跟 target model 愈匹配的話, accept rate (βt,α) 就會愈高
因此 βt 或 α 可以看成是小模型跟大模型的匹配程度.
但是再繼續之前, 我們必須先回顧一下幾何分佈
Geometric distribution with capped number of trails
考慮一次測試 (trail) 的成功機率為 θ, 最多測試 n 次 trails, random variable X 代表要花多少次的 trails 才會至少成功一次. 注意到如果前 n−1 次都 fail, 則強制最後第 n 次一定成功.
前 n−1 次至少會 success 一次所需花的 trails 次數期望值為:
1×第一次就成功的機率+2×第二次才就成功的機率+...+(n−1)×第(n−1)次才成功的機率
θn−1∑x=1x(1−θ)x−1=θn−1∑x=1(−ddθ(1−θ)x)=−θddθ(n−1∑x=1(1−θ)x)=−θddθ((1−θ)(1−(1−θ)n−1)1−(1−θ))=−θddθ((1−θ)−(1−θ)nθ)=−θθ(−1+n(1−θ)n−1)−(1−θ)+(1−θ)nθ2=θ−nθ(1−θ)n−1+(1−θ)−(1−θ)nθ
E[X]=θ−nθ(1−θ)n−1+(1−θ)−(1−θ)nθ+n(1−θ)n−1=1−(1−θ)nθ
計算平均生成的 tokens 數
E(#generated tokens) 相當於要計算試驗次數有上限 (capped number of trails) 的 geometric distribution 的期望值.
對應到 speculative decoding 的問題裡 θ=1−α, 且試驗次數最多 γ+1 次., 因此將 θ=1−α, n=γ+1 代入得到:
E[#generated tokens]=1−αγ+11−α 我們發現 Mq 與 Mp 愈匹配的話, speculative decoding 一次 run 產生的 tokens 愈多 (很合理, 因為被接受的機率愈高)
產生的 tokens 上限就是 γ+1 (γ 個 proposal tokens 全被接受加上最後一個 Mp 產生的 token)
待續 …
References
- Google: Fast Inference from Transformers via Speculative Decoding [arvix]
- DeepMind: Accelerating Large Language Model Decoding with Speculative Sampling [arxiv]
- Speculative_sampling.drawio
- Speculative Decoding 詳讀 (下)