Speculative Sampling

Speculative Sampling 【LLM系列 | 训练&推理加速】投机采样

This post provides an overview, implementation, and time complexity analysis of DeepMind's paper Accelerating Large Language Model Decoding with Speculative Sampling.

Code for this blog post can be found at github.com/jaymody/speculative-samlping.

EDIT (Apr 13th, 2023): Updated code and time complexity to avoid the extra forward pass of the draft model (credits to KexinFeng).

Autoregressive Sampling

The standard way of generating text from a language model is with autoregressive sampling, here's the algorithm as defined in the paper:

In code:

def autoregressive_sampling(x, model, N):
    n = len(x)
    T = len(x) + N

    while n < T:
        x = np.append(x, sample(model(x)[-1]))
        n += 1

    return x

Where:

  • x is a list of integers representing the token ids of the input text
  • model is a language model (like GPT-2) that accepts as input a list of token ids of length seq_len and outputs a matrix of probabilities of shape [seq_len, vocab_size].
  • N is the number of tokens we want to decode.

The time complexity of this algorithm is O(N⋅tmodel):

  • N: The number of iterations of our while loop, which is just the number of tokens to decode N.
  • tmodel: The time complexity of each iteration in the loop, which is just the time taken for a single forward pass of our model tmodel.

Speculative Sampling

In speculative sampling, we have two models:

  1. A smaller, faster draft model (e.g. DeepMind's 7B Chinchilla model)
  2. A larger, slower target model (e.g. DeepMind's 70B Chinchilla model)

The idea is that the draft model speculates what the output is 

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

张博208

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值