FlashAttention-2:GPU利用率从50%干到90%,同一个作者的自我进化

封面

上一篇我们聊了FlashAttention——通过减少GPU显存和缓存之间的数据搬运,让注意力机制的速度提升了3倍。那个方法很漂亮,但作者Tri Dao觉得自己还能做得更好。

一年后,他交出了FlashAttention-2。

如果说FlashAttention是"发现了一条近路",那FlashAttention-2就是"把这条路修成了高速公路"。目的地完全一样——精确的注意力计算,不做任何近似——但速度又快了2-3倍

他是怎么做到的?答案藏在一个你可能从没想过的角度:GPU的"工作量分配"问题

FlashAttention的尴尬:GPU有一半在摸鱼

FlashAttention已经很棒了——它通过分块计算避免了大量HBM读写。但如果你拿性能分析工具去看GPU的实际运行状态,会发现一个让人坐不住的事实:GPU的利用率只有30-50%

什么意思?你花大价钱买的A100,有一半的算力在那儿空转。

打个比方。你有一个工厂(GPU),工厂里有众多生产线(GPU里包含多个流多处理器和线程束)。FlashAttention的任务分配方式是:每次只让少数几条线开工,其他线等着。这批做完了,再让下一批上。

工厂一直在运转,但很多生产线在摸鱼。老板看了肯定心疼。

Image
FlashAttention-2的前向传播流程:通过分块避免实例化完整的N×N注意力矩阵,减少HBM访问(来源:原论文Figure 1)

为什么会这样?根本原因是FlashAttention的并行粒度太粗了。它只按batch和head两个维度来分配任务,而这两个维度的大小不一定能正好填满GPU上所有的线程束。就像你有很多条生产线,但每次只拿到少量订单,剩下的大部分生产线只能干瞪眼。

三招把GPU榨干:FlashAttention-2的改进

FlashAttention-2围绕"让GPU别闲着"这个核心目标,做了三个关键改进。每一个都是对GPU硬件特性的深度理解。

改进一:更好的并行分配——把序列长度也切开分。

FlashAttention只按batch和head分配任务。FlashAttention-2加了一个维度:序列长度。一个长序列可以切成好几段,分给不同的线程组并行处理。

回到工厂的比喻:原来一条生产线只能做一种产品(一个注意力头),做完了才能接下一个。现在允许一条生产线把产品拆成几段,多个工人同时做不同段。这样就没有生产线闲着了。

这个改动看似简单,但它意味着不管你的batch size多小、head数多少,只要序列够长,就能把GPU塞满。在实际训练场景中,序列长度往往是最大的维度,所以这个改动的收益特别大。

改进二:把非矩阵运算压缩到最少。

GPU被设计出来的核心目的就是做矩阵乘法。做矩阵乘法时,GPU可以跑到接近理论峰值的算力。但注意力机制里不全是矩阵乘法——还有softmax里的指数运算、归一化、dropout这些标量操作。GPU做这些事情效率低得多。

FlashAttention-2的做法是:通过巧妙的数学重排,把softmax的归一化因子等操作重新组织,让它们尽可能"搭便车"到矩阵乘法的路径上。具体来说,每个线程束只负责输出矩阵的一行,从Q的一行出发,遍历所有K和V的块,边算边更新结果。这样大部分计算都走矩阵乘法的"快车道",只有少量标量操作走"慢车道"。

Image
FlashAttention-2的并行化策略:不同线程块(thread block)处理不同的query块,每个块内的warp分工协作(来源:原论文Figure 2)

改进三:让线程少"开会",多干活。

GPU内部的线程之间需要共享中间结果。在FlashAttention中,线程之间需要频繁同步(相当于频繁"开会"),来交换softmax的中间状态。每次同步都意味着有些线程要等别的线程干完才能继续。

FlashAttention-2的做法是:给每个线程束分配独立的Q行,让它从头到尾自己做完,不需要跟其他线程束频繁沟通。就像把一个大项目拆成若干个独立的小项目,每个工程师(线程束)负责一个,做完最后汇总就行,不用每天开站会。

三招加在一起,效果就是GPU的利用率从30-50%提升到了最高73%。

数据说话:到底快了多少?

论文里的数据很硬核:

**相比PyTorch标准注意力实现:**根据推算综合提升约4-8倍(FlashAttention的2-4倍再叠加约2倍提速)。

相比FlashAttention v1:又快了约2-3倍。注意,v1已经比标准实现快3倍了,所以这是"快的上面更快"。

绝对性能:在A100上,FlashAttention-2的端到端训练GPT模型时,整体训练速度达到了225 TFLOPs/s,接近A100理论峰值312 TFLOPs/s的72%。在H100上更是达到了约575 TFLOPs/s。要知道H100的FP16/BF16理论峰值也就约990 TFLOPs/s,这个利用率已经相当惊人了。

Image
FlashAttention-2在A100上的前向+反向传播速度,最高达到225 TFLOPs/s,接近理论峰值利用率(来源:原论文Figure 3)

**端到端训练速度:**在训练GPT-3规模的模型时,FlashAttention-2在端到端训练中也能带来显著的加速体验。注意力只是训练的一部分(还有前向传播、反向传播、优化器更新等),所以端到端的提升没有纯注意力计算那么大,但仍然非常可观。

**再次强调:**FlashAttention-2的数学输出和标准注意力一模一样。不是近似,不是trade精度换速度,是完完全全一样的结果。纯粹靠更聪明的实现方式拿到了2-3倍的加速。

为什么这件事值得单独写一篇论文?

你可能会想:“不就是优化了一下任务分配嘛,至于发一篇论文?”

答案是:这个优化的经济价值太大了。

训练一个大模型动辄几百万到几千万美元。注意力机制是Transformer训练中计算量最大的部分之一。如果这部分能快2倍,意味着训练成本可以省下几十个百分点。对于每天都在跑训练和推理的AI公司来说,这是真金白银。

更重要的是,FlashAttention-2让更长的上下文窗口变得实际可行。在v1的基础上再快2-3倍,意味着128K甚至更长上下文的训练不再是天方夜谭。长上下文直接关系到模型的能力边界——能记住更多的上下文,就能做更复杂的推理、处理更长的文档。

这也是为什么Mistral、Meta(Llama系列)、Google(Gemini)等公司都在第一时间把FlashAttention-2集成到了自己的训练框架里。从v1到v2,不是锦上添花,是雪中送炭。

作为工程师,我从这篇论文里得到的一个重要认知是:**“利用率"是一个被严重低估的指标。**不只是GPU,我们日常做的很多系统优化——数据库连接池、线程池、缓存命中率——本质上都是在提高利用率。当你觉得某个东西"已经够快了”,问问自己:理论极限是多少?现在的实际利用率是多少?差距就是优化空间。

另一个收获是:好的工作值得持续打磨。FlashAttention v1已经是行业突破了,但Tri Dao没有止步。他深入到GPU硬件层面,逐个分析warps在干什么、哪些时间花在了等待上、哪些操作走了慢路径,然后把每一个瓶颈都优化掉。这种"追求极致"的工程思维,比论文本身更值得学习。

论文链接:https://arxiv.org/abs/2307.08691

kk的大模型论文学习笔记 · 第9篇 · FlashAttention-2

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值