TPU7x (Ironwood) 性能优化

本指南介绍了多种方法,可用于通过高效管理 TPU7x (Ironwood) 多层级内存系统之间的数据移动来优化性能。这包括低精度训练、分片、通信优化、激活重新实物化、范围限定的虚拟内存调优和自定义加速器内核等技术。

如需使用 TPU7x 优化性能,您必须先熟悉 Ironwood 架构,尤其是内存层次结构和互连拓扑。如需了解详情,请参阅 TPU7x (Ironwood)

使用 FP8 进行低精度训练

FP8(8 位浮点)是一种高效的数值数据格式,主要用于加速模型训练和推理。通过使用 8 位(而非标准的 16 位格式 [FP16 或 BF16] 和 32 位 [FP32])表示数字,TPU 可以显著加快数据处理速度并减少内存用量。

TPU7x 支持内置的 FP8 数据类型硬件加速,每个芯片的理论峰值性能为 4614 TFLOPS。此功能可显著缩短端到端训练时间。对于兼容的操作,尤其是 AI 工作负载中常见的密集矩阵乘法,使用 FP8 可将性能提升到标准 BF16 训练的 1.3 倍。与 BF16 相比,FP8 的峰值 FLOPs 翻了一番,权重和激活的内存占用减少了一半。对于受内存容量或带宽限制的计算密集型工作负载和场景,FP8 应该是主要的调优手段。

使用 FP8 可带来以下性能优势:

  • 降低高带宽内存 (HBM) 压力:更小的内存占用空间可让更大的模型或在推理期间具有更大 KV 缓存的模型完全适合 192 GB 的 HBM。这样可以避免昂贵的卸载到较慢的主机内存。
  • 提高有效批次大小:通过减少激活所需的内存,FP8 可实现更大的批次大小。这有助于提高数据并行性,并可提高吞吐量和计算单元的利用率。
  • 降低内存带宽要求:每次操作移动一半的数据量可降低对 HBM 到 MXU 数据路径的需求。在数据移动是常见瓶颈的系统中,这有助于使 MXU 始终处于饱和工作状态。

使用 FP8 时,如果性能下降为零或有限,则需要仔细选择量化技术。以下是一些可考虑用于 FP8 训练的最佳实践:

  • 缩放粒度:以每个张量的伸缩作为基准。如果存在质量或性能问题,请切换到按轴伸缩。 子频道伸缩可能是不必要的。
  • 缩放模式:动态缩放会在运行时计算伸缩比例,是保持质量的理想默认设置。虽然静态伸缩可以通过消除计算来显著提升性能,但它需要仔细的分析来确定正确的伸缩比例,并且可能不适合所有使用情形,尤其是在模型配置发生变化时。相反,一些稳健的模型和配置可以将权重或激活的缩放比例固定为 FP8 上限,从而减少量化开销,同时保持准确性并提升性能。
  • FP8 格式(E4M3 和 E5M2):一种常见且有效的方法是混合使用 FP8 格式。例如,在正向传递中使用 E4M3 来表示权重和激活,以利用 E4M3 的更高精度;在反向传递中使用 E5M2 来表示梯度,以适应梯度更广的动态范围。
  • 舍入:使用“四舍六入五成双”舍入 (RNE) 而不是随机舍入来处理梯度,可以在保持质量的同时提供更好的性能和可重现性。
  • 在 MaxText 中启用 FP8MaxText 通过 QWIX 量化库支持 FP8 训练。如需激活量化,请在配置中设置以下标志:use_qwix_quantization=true

分片和并行性

分片是指将大型模型或其训练数据切分为更小的部分,并将其分配到多个 TPU 芯片或核心的过程。 选择合适的分片策略对于在 TPU7x 上实现高性能至关重要。

一种单纯地最大限度提高并行度的简单方法通常会因受通信限制而导致性能不佳。最佳方法通常是选择满足内存限制的最简单分片策略,因为这样可以最大限度地减少通信开销,并高效利用计算单元。

在选择分片策略之前,任何性能调优工作的第一步都应该是算术强度分析。此分析可确定给定的计算是否受计算、内存带宽或互连带宽的限制。计算方式为浮点运算次数与必须移动的数据字节数的比率。

较高的算术强度表示工作负载受计算限制。较低的算术强度表明工作负载受内存或通信限制,性能受数据从 HBM 或通过 ICI 网络移动的速度限制。此分析可为确定理想的批次大小和分片策略提供依据。例如,对于通信密集型工作负载,引入更多通信的分片策略(例如高程度的张量并行处理)不会带来任何好处。

分片策略决策框架

MaxText 提供多种分片策略。最佳选择取决于模型架构、序列长度,以及是否需要平衡计算负载与通信开销。

  • 完全分片数据并行处理 (FSDP):这是数据并行处理的首选默认策略。FSDP 会跨数据并行设备对模型权重、梯度和优化器状态进行分片。在计算过程中,每个设备都会执行一次 All-Gather 操作,以检索其本地微批次所需的完整权重。只要每个设备的批次大小足够大,能够隐藏此 All-Gather 通信的延迟,FSDP 就非常有效。对于混合专家 (MoE) 模型,算术强度计算必须考虑稀疏性。
  • 张量并行处理 (TP):TP 会跨设备对各个张量进行分片。通常,张量是多层感知器 (MLP) 和注意力块中的权重矩阵。硬件的高算术强度 (11.5k) 对模型维度提出了非常高的要求,以使 TP 在 ICI 上可行,尝试使用 TP 可能会导致系统受通信限制。
  • 专家并行 (EP):这是训练 MoE 模型的标准且必要的策略。EP 会将“专家”层分片到一组设备中,并使用全到全通信集合将令牌路由到其指定的专家设备。如果模型的 MLP 维度足够大,接近屋顶线,EP 就能高效运行。
  • 上下文并行 (CP):CP 是一种专门的策略,对于训练序列长度非常长的模型至关重要。其主要功能是管理激活的内存消耗,该内存消耗随序列长度呈二次方增长,并可能超过 HBM 容量。CP 会对激活张量的序列维度进行分片,从而允许使用部分每设备批次大小。由于 CP 引入的通信量比 FSDP 多,因此一般规则是使用满足内存限制并确保批次轴分片保持为整数所需的最小 CP 程度。

下表将常见的工作负载类型映射到最佳分片策略:

工作负载类型 建议的主分片 次要分片 主要瓶颈 原因
密集模型 - 短序列 FSDP 不适用 重新具体化,FF Matmuls FSDP 可提供最佳平衡。对于短序列,激活内存可能不是主要问题。关键在于全局批次足够大,可以隐藏 FSDP 的权重 All-Gather。随着批次大小的增加,激活大小也会增加,因此需要采用合适的重新实物化政策,以确保此配置不会耗尽内存。
密集模型 - 长序列 FSDP CP Flash Attention、激活内存 激活内存成为主要限制因素。CP 是启用每设备分数批次大小并避免内存不足 (OOM) 问题的必要条件。Flash Attention 是计算和浪费时间的主要来源。
MoE 模型 - 短序列 FSDP + EP 不适用 全到全(专家路由),重新实现 MoE 模型需要 EP 来对专家进行分片。用于令牌路由的全到全通信是一个必须重叠的主要瓶颈。重新物化也是一个重要的浪费来源。
MoE 模型 - 超大规模 FSDP + EP + PP 模型并行处理 (MP) 之前提及的所有瓶颈,以及流水线气泡 对于超出单个 Pod 内存的模型,需要 PP 来跨 Pod 对层进行分片。这会引入 DCN 通信和流水线气泡开销。这是一种非常复杂的配置,需要仔细调优。

沟通优化

在 TPU7x 上重叠通信和计算的主要机制称为 SparseCore 集体分流。Ironwood 架构包含专用 SparseCore 单元,这些单元充当独立的控制线程,能够管理 ICI 结构上的数据移动。这样一来,集体通信操作(如 All-Gather 或 Reduce-Scatter)就可以与 TensorCore 上进行的主要计算并行执行。这是建议用于 TPU7x 上异步集合的方法。使用推荐的标志,以针对最常见的集合启用分流。

激活重新序列化

激活重新实体化(也称为梯度检查点)是一种用于减少模型 HBM 占用空间的基本技术。它不会将前向传播中的所有中间激活都存储在 HBM 中以供反向传播期间使用,而是仅保存一些关键激活(检查点),并在反向传播期间按需重新计算其他激活。这样可以节省大量内存,但会增加计算量(对于标准 Transformer 块,大约会增加 25-30% 的 FLOP)。

如何积极地应用重新实物化是一项关键的调优参数,完全取决于主要瓶颈,而主要瓶颈通常会随序列长度而变化。

对于长序列工作负载(例如 128k):在这些情况下,激活张量的大小是 HBM 的主要消耗者。工作负载通常受内存限制。因此,应用激进的重新物化政策非常有益。节省的内存可确保训练顺利进行,不会出现内存不足错误,还可实现更大的批次大小,而重新计算的计算开销是值得的。

对于短序列工作负载(例如 8k):在这些情况下,激活内存的顾虑要少得多,工作负载更有可能受计算限制。重新实物化的计算开销可能是效率低下的最大来源。

在 MaxText 中调整重新实物化政策

MaxText 通过一组预设和自定义政策(使用 remat_policy 标志配置)提供对重新实物化的精细控制。

预设政策

MaxText 提供以下内置政策:

  • full:最激进的政策,几乎重新实现所有内容。 这会最大限度地减少 HBM 使用量,但会最大限度地增加重新计算开销。非常适合内存极其有限的长序列场景。
  • minimal:最不激进的政策,可存储大多数激活。这样可以最大限度地利用 HBM,同时最大限度地减少重新计算。最适合短序列、计算密集型工作负载,且无需考虑内存。
  • 中级政策save_dot_with_context_except_mlpsave_qkv_projsave_out_proj 等选项通过选择性地对开销大的点积运算的输出进行检查点设置,同时重新实现开销较小的逐元素运算,从而提供各种权衡方案。

自定义政策

如需更精细的控制,您可以将 remat_policy 设置为 custom。这样一来,您就可以指定模型解码模块中各个层的行为。每个层都可以分配以下三种行为之一:

  • device:激活存储在 TPU 设备的 HBM 中。
  • remat:系统会舍弃激活,并在反向传播期间重新实现。
  • offload:激活从 HBM 移至 CPU 主机的内存,以 PCIe 传输延迟为代价释放 HBM。

范围限定的 VMEM 调整

内核性能(例如 Flash Attention)取决于内核中选择的 tile 大小,而 tile 大小受可用向量内存 (VMEM) 的限制。TPU7x 芯片中的每个 TensorCore 都有 64 MiB 的矢量内存 (VMEM)。此 VMEM 容量可在当前范围(有范围的 VMEM)和未来的权重预取之间进行拆分。增加范围限定的 VMEM 可在内核中增加 tile 大小,从而可能减少内存停滞并提高内核的性能。您可以通过设置 xla_tpu_scoped_vmem_limit_kib(在 LIBTPU_INIT_ARGS 中)来更改有范围的 VMEM 大小,这可用于探索内核性能以及端到端性能限制。优化范围限定的 VMEM 大小可能会间接影响自定义 Pallas 内核性能,因为增加范围限定的 VMEM 会为内核中的 tile 大小解锁更大的超参数搜索空间。

Tokamax 内核

Tokamax 是一个高性能 JAX 内核库,包含许多高度优化的 TPU 内核,可解决几个常见的硬件特定瓶颈问题:

  • Splash attention:Splash attention 用作主要的注意力机制实现,以消除标准注意力机制的 HBM 瓶颈,并在 TPU 上使用最有效的注意力机制实现。
  • Megablox 分组矩阵乘法 (GMM):对于 MoE 工作负载,Megablox 通过计算不规则激活表示法来高效处理分组矩阵乘法。它可以高效地映射到不规则维度,计算 LHS 中不规则的行组与相应专家矩阵之间的矩阵乘法,从而避免将批次填充到固定大小。
  • 使用 tune-jax 进行实证调优tune-jax 库包含一些实用程序,可用于实证搜索最佳块大小。默认内核大小通常不是最佳选择;通过调整,可以选择硬件友好的 VMEM tile 大小,从而最大限度地提高硬件利用率。
  • 最大 logits 估计值:通过为 max_logit_const 设置值,可以进一步优化 Tokamax Splash 注意力内核。如果设置,它会替换注意力机制(softmax(Q * KT))的 softmax 运算期间的最大 logit 的缩减计算,从而减少一些计算和同步开销。在 MaxText 中,它是通过配置 use_max_logits_estimate 实现的,可以将其设置为 None(停用)或浮点值。验证特定模型的 logit 范围是否仍与估计值兼容,以防止出现数值溢出。 如果设置了此值,建议进行收敛测试。