Speculative Sampling [LLM Series | Training & Inference Acceleration]
This post provides an overview, implementation, and time complexity analysis of DeepMind's paper Accelerating Large Language Model Decoding with Speculative Sampling.
Code for this blog post can be found at github.com/jaymody/speculative-samlping.
EDIT (Apr 13th, 2023): Updated code and time complexity to avoid the extra forward pass of the draft model (credits to KexinFeng).
Autoregressive Sampling
The standard way to generate text from a language model is autoregressive sampling. Here's the algorithm as defined in the paper:
In code:
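```python
import numpy as np

def sample(p):
    # Assumed helper (not shown in this excerpt): draw one token id from probabilities p.
    return np.random.choice(np.arange(p.shape[-1]), p=p)

def autoregressive_sampling(x, model, N):
    n = len(x)
    T = len(x) + N
    while n < T:
        x = np.append(x, sample(model(x)[-1]))  # sample from the last position's distribution
        n += 1
    return x
```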
Where:
- x is a list of integers representing the token ids of the input text.
- model is a language model (like GPT-2) that accepts as input a list of token ids of length seq_len and outputs a matrix of probabilities of shape [seq_len, vocab_size].
- N is the number of tokens we want to decode.
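To make the shape contract concrete, here is a minimal sketch that runs `autoregressive_sampling` end to end. The `toy_model` below is hypothetical (it is not GPT-2): it looks only at each position's token and returns rows of a random but fixed transition table, so every row of its output is a valid probability distribution over a made-up vocabulary of size `VOCAB_SIZE = 8`. The `sample` helper is likewise an assumption.

```python
import numpy as np

VOCAB_SIZE = 8  # hypothetical vocabulary size for the toy model
rng = np.random.default_rng(0)
# Hypothetical "bigram" table: row i is the next-token distribution after token i.
TABLE = rng.random((VOCAB_SIZE, VOCAB_SIZE))
TABLE /= TABLE.sum(axis=1, keepdims=True)


def toy_model(x):
    """Hypothetical stand-in for a language model.

    Takes a 1-D array of token ids of length seq_len and returns a
    [seq_len, vocab_size] matrix whose row i is the distribution over
    the token that follows position i.
    """
    return TABLE[np.asarray(x, dtype=int)]


def sample(p):
    """Assumed helper: draw a single token id from the probability vector p."""
    return int(rng.choice(len(p), p=p))


def autoregressive_sampling(x, model, N):
    n = len(x)
    T = len(x) + N
    while n < T:
        # Only the last row is needed: the distribution for the next token.
        x = np.append(x, sample(model(x)[-1]))
        n += 1
    return x


prompt = np.array([1, 2, 3])
probs = toy_model(prompt)
print(probs.shape)  # (3, 8): [seq_len, vocab_size]
out = autoregressive_sampling(prompt, toy_model, N=5)
print(len(out))  # 8: the 3 prompt tokens plus 5 decoded tokens
```

Note that the model returns a distribution for every position, but autoregressive sampling only ever uses the last row; the other rows become useful once a second model proposes several tokens at once.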
The time complexity of this algorithm is $O(N \cdot t_{\text{model}})$:
- $N$: The number of iterations of our while loop, which is just the number of tokens to decode, $N$.
- $t_{\text{model}}$: The time complexity of each iteration in the loop, which is just the time taken for a single forward pass of our model, $t_{\text{model}}$.
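One way to check the $O(N \cdot t_{\text{model}})$ count empirically is to wrap the model and count its calls. In this sketch, `CountingModel`, `uniform_model`, and `sample` are hypothetical stand-ins introduced for illustration; `autoregressive_sampling` is the function from above.

```python
import numpy as np

VOCAB_SIZE = 4  # hypothetical vocabulary size
rng = np.random.default_rng(0)


class CountingModel:
    """Hypothetical wrapper that counts how many forward passes are made."""

    def __init__(self, model):
        self.model = model
        self.calls = 0

    def __call__(self, x):
        self.calls += 1
        return self.model(x)


def uniform_model(x):
    # Hypothetical model: uniform next-token probabilities, shape [seq_len, vocab_size].
    return np.full((len(x), VOCAB_SIZE), 1.0 / VOCAB_SIZE)


def sample(p):
    # Assumed helper: draw one token id from the probability vector p.
    return int(rng.choice(len(p), p=p))


def autoregressive_sampling(x, model, N):
    n = len(x)
    T = len(x) + N
    while n < T:
        x = np.append(x, sample(model(x)[-1]))
        n += 1
    return x


calls = {}
for N in (1, 10, 50):
    counted = CountingModel(uniform_model)
    autoregressive_sampling(np.array([0, 1]), counted, N)
    calls[N] = counted.calls
print(calls)  # {1: 1, 10: 10, 50: 50}
```

Each loop iteration makes exactly one forward pass and yields exactly one token, so the total cost is $N$ passes of the model. That one-token-per-pass ratio is the factor speculative sampling tries to improve.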
Speculative Sampling
In speculative sampling, we have two models:
- A smaller, faster draft model (e.g. DeepMind's 7B Chinchilla model)
- A larger, slower target model (e.g. DeepMind's 70B Chinchilla model)
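For a concrete picture of how the two models interact, below is a minimal NumPy sketch of speculative sampling following the modified rejection sampling scheme from the DeepMind paper. It assumes both models share the `model` interface above (token ids in, `[seq_len, vocab_size]` probabilities out); the toy bigram models, `make_toy_model`, `max_fn`, and the lookahead `K` values are illustrative assumptions. Each round, the draft model proposes `K` tokens autoregressively and the target model scores all of them in a single forward pass. Draft token $\tilde{x}$ is then accepted with probability $\min(1, q(\tilde{x})/p(\tilde{x}))$, and on the first rejection a replacement is drawn from the normalized residual $\max(0, q - p)$. If all `K` are accepted, one bonus token is sampled from the target's final distribution.

```python
import numpy as np

rng = np.random.default_rng(0)
VOCAB_SIZE = 4  # hypothetical toy vocabulary


def make_toy_model(seed):
    """Hypothetical bigram 'model': returns [seq_len, vocab_size] probabilities."""
    table = np.random.default_rng(seed).random((VOCAB_SIZE, VOCAB_SIZE))
    table /= table.sum(axis=1, keepdims=True)
    return lambda x: table[np.asarray(x, dtype=int)]


def sample(p):
    # Assumed helper: draw one token id from the probability vector p.
    return int(rng.choice(len(p), p=p))


def max_fn(v):
    # Residual distribution: clip negatives to zero, then renormalize.
    v = np.maximum(v, 0.0)
    return v / v.sum()


def speculative_sampling(x, draft_model, target_model, N, K):
    """Decode at least N tokens after a non-empty prompt x (may overshoot by up to K)."""
    T = len(x) + N
    while len(x) < T:
        n0 = len(x)
        # 1. Draft model proposes K tokens autoregressively, keeping its distributions.
        x_draft, p_rows = x, []
        for _ in range(K):
            p_next = draft_model(x_draft)[-1]
            p_rows.append(p_next)
            x_draft = np.append(x_draft, sample(p_next))

        # 2. One target forward pass scores all K proposals (plus one extra position).
        q = target_model(x_draft)

        # 3. Accept draft token i with probability min(1, q/p); stop at the first rejection.
        all_accepted = True
        for i in range(K):
            token = x_draft[n0 + i]
            p_i, q_i = p_rows[i], q[n0 + i - 1]
            if rng.random() < min(1.0, q_i[token] / p_i[token]):
                x = np.append(x, token)
            else:
                x = np.append(x, sample(max_fn(q_i - p_i)))
                all_accepted = False
                break

        # 4. If every proposal was accepted, the target pass yields one bonus token.
        if all_accepted:
            x = np.append(x, sample(q[-1]))
    return x


draft, target = make_toy_model(1), make_toy_model(2)
out = speculative_sampling(np.array([0, 1, 2]), draft, target, N=10, K=4)
print(len(out))  # somewhere between 13 and 17 tokens
```

With identical draft and target models every proposal is accepted (the ratio $q/p$ is exactly 1), so each round emits `K + 1` tokens for a single target forward pass. That is the best case the speedup relies on. In general the number of accepted tokens per round depends on how closely the draft model's distributions track the target's.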
The idea is that the draft model speculates what the output is





