TPU7x (Ironwood) 效能最佳化

本指南說明如何有效管理多層記憶體系統之間的資料移動,藉此運用 TPU7x (Ironwood) 最佳化效能。包括低精確度訓練、分片、通訊最佳化、啟用重新實體化、範圍虛擬記憶體調整,以及自訂加速器核心等技術。

如要使用 TPU7x 提升效能,請先熟悉 Ironwood 架構,特別是記憶體階層和互連拓撲。詳情請參閱「TPU7x (Ironwood)」。

使用 FP8 進行低精確度訓練

FP8 (8 位元浮點數) 是一種有效率的數值型資料格式,主要用於加速模型訓練和推論。TPU 使用 8 位元 (而非標準的 16 位元格式 (FP16 或 BF16) 和 32 位元 (FP32)) 表示數字,因此處理資料的速度大幅提升,且使用的記憶體較少。

TPU7x 支援 FP8 資料類型的內建硬體加速功能,每個晶片的理論峰值效能為 4614 TFLOPS。這項功能可大幅縮短端對端訓練時間。對於相容的作業,尤其是 AI 工作負載常見的密集矩陣乘法,使用 FP8 可比標準 BF16 訓練提升 1.3 倍的效能。與 BF16 相比,FP8 的尖峰 FLOP 數是 BF16 的兩倍,權重和啟動的記憶體用量則只有 BF16 的一半。對於受記憶體容量或頻寬限制的運算量受限工作負載和情境,FP8 應是主要的調整槓桿。

使用 FP8 可享有下列成效優勢:

  • 減少高頻寬記憶體 (HBM) 壓力:記憶體用量減少後,較大的模型或推論期間具有較大 KV 快取的模型,就能完全納入 192 GB 的 HBM。這樣可避免卸載至速度較慢的主機記憶體,節省成本。
  • 有效批量大小增加:FP8 可減少啟動所需的記憶體,因此能使用較大的批量大小。這有助於提升資料平行處理能力,進而提高處理量,並更有效運用運算單元。
  • 降低記憶體頻寬需求:每次運算移動一半的資料量,可減少 HBM 到 MXU 資料路徑的需求。在資料移動是常見瓶頸的系統上,這有助於讓 MXU 保持工作飽和狀態。

如要使用 FP8,同時確保效能不會降低或只會略微降低,就必須謹慎選擇量化技術。以下是 FP8 訓練的最佳做法:

  • 縮放精細度:以每個張量的縮放比例做為基準。如有品質或效能問題,請改用依軸縮放。 子聲道縮放可能不需要。
  • 縮放模式:動態縮放會在執行階段計算縮放比例係數,是維持品質的理想預設模式。雖然靜態縮放功能可消除運算作業,大幅提升效能,但您必須仔細分析,才能判斷正確的縮放比例,而且可能不適合所有用途,尤其是模型設定變更時。反之,部分強大的模型和設定可將權重或啟動的比例修正為 FP8 限制,讓您減少量化負擔,同時維持準確度並提升效能。
  • FP8 格式 (E4M3 和 E5M2):常見且有效的方法是混合使用 FP8 格式。舉例來說,您可以在前向傳遞中使用 E4M3 做為權重和啟動,藉此利用 E4M3 的更高精確度,並在後向傳遞中使用 E5M2 做為梯度,以適應梯度的更廣動態範圍。
  • 四捨五入:對漸層使用「四捨五入至最接近的偶數」(RNE) 而非隨機四捨五入,可維持品質,同時提供更優異的效能和可重現性。
  • 在 MaxText 中啟用 FP8MaxText 支援透過 QWIX 量化程式庫進行 FP8 訓練。如要啟用量化功能,請在設定中設定下列標記:use_qwix_quantization=true

分片與平行處理

分片是將大型模型或訓練資料切成較小的片段,並分配到多個 TPU 晶片或核心的程序。選擇合適的分片策略,是 TPU7x 達到高效能的重要因素。

如果只是單純地盡量提高平行處理程度,通常會因為受到通訊限制而導致效能不佳。通常,最佳做法是選取最簡單的切分策略,以符合記憶體限制,因為這樣可盡量減少通訊負擔,並有效運用運算單元。

選取分片策略前,任何效能調整作業的第一步都應該是算術強度分析。這項分析會判斷特定運算是否受到運算、記憶體頻寬或互連頻寬限制。計算方式為浮點運算與必須移動的資料位元組數的比率。

算術密集度高表示工作負載受運算能力限制。算術強度偏低表示工作負載受記憶體或通訊限制,效能會受到資料從 HBM 或 ICI 網路移轉的速度限制。這項分析可做為理想批量大小和分片策略的參考依據。舉例來說,如果工作負載受通訊限制,那麼引入更多通訊的策略 (例如高程度的張量平行處理) 就不會帶來好處。

分片策略決策架構

MaxText 提供多種分片策略。最佳選擇取決於模型架構、序列長度,以及是否需要平衡運算負載與通訊負擔。

  • 完全分片資料平行處理 (FSDP):這是資料平行處理的偏好預設策略。FSDP 會在資料平行裝置之間,將模型權重、梯度和最佳化工具狀態分片。在運算期間,每個裝置都會執行 All-Gather 作業,以擷取本機微批次所需的完整權重。只要每個裝置的批量夠大,足以隱藏這項 All-Gather 通訊的延遲,FSDP 就非常有效。對於專家混合 (MoE) 模型,算術強度計算必須考量稀疏性。
  • 張量平行化 (TP):TP 會在裝置間分割個別張量。通常張量是多層感知器 (MLP) 和注意力區塊中的權重矩陣。硬體的算術密集度高達 11.5k,因此模型維度必須非常高,才能透過 ICI 實現 TP,否則系統可能會受到通訊限制。
  • 專家平行化 (EP):這是訓練 MoE 模型時的標準必要策略。EP 會將「專家」層分散到一組裝置上,並使用 All-to-All 通訊集合,將權杖路由至指定的專家裝置。如果模型的 MLP 維度夠大,可接近屋頂線,EP 就能發揮效率。
  • 內容平行處理 (CP):CP 是一種專門策略,對於訓練序列長度極長的模型至關重要。主要功能是管理啟動的記憶體用量,這會隨著序列長度呈二次方成長,並可能超過 HBM 容量。CP 會將啟動張量的序列維度分片,因此可以使用每個裝置的分數批量大小。由於 CP 比 FSDP 引入更多通訊,一般規則是使用滿足記憶體限制的最低 CP 程度,並確保批次軸分片保持為整數。

下表列出常見工作負載類型與最佳分片策略的對應關係:

工作負載類型 建議的主要分片 次要分片 主要瓶頸 理由
密集模型 - 短序列 FSDP 不適用 Rematerialization、FF Matmuls FSDP 則可提供最佳平衡。如果是短序列,啟動記憶體可能不是主要問題。關鍵在於要有足夠大的全域批次,才能隱藏 FSDP 的權重 All-Gather。隨著批量大小增加,啟用大小也會增加,因此需要適當的重新實體化政策,確保這項設定不會耗盡記憶體。
密集模型 - 長序列 FSDP CP 快閃注意力、啟動記憶體 啟動記憶體會成為主要限制。啟用每個裝置的分數批次大小,並避免記憶體不足 (OOM) 問題。Flash attention 是運算和浪費時間的主要來源。
MoE 模型 - 短序列 FSDP + EP 不適用 All-to-All (專家轉送)、重新實體化 MoE 模型需要 EP 來分片專家。權杖路徑的全對全通訊是主要瓶頸,必須重疊。重新實體化也是造成浪費的重要原因。
MoE 模型 - 超大規模 FSDP + EP + PP 模型平行處理 (MP) 先前提及的所有瓶頸,以及管道泡泡 如果模型超出單一 Pod 的記憶體容量,就必須使用 PP 將層級分片到各個 Pod。這會導致 DCN 通訊和管道泡泡的額外負荷。這項設定非常複雜,需要仔細調整。

通訊最佳化

在 TPU7x 上重疊通訊和運算的主要機制稱為「SparseCore Collective Offloading」。Ironwood 架構包含專屬的 SparseCore 裝置,可做為獨立的控制執行緒,負責管理 ICI 結構的資料移動作業。這樣一來,集體通訊作業 (例如 All-Gather 或 Reduce-Scatter) 就能與 TensorCore 上執行的主要運算作業並行執行。這是 TPU7x 上非同步集合的建議方法。使用建議的旗標,為最常見的集合啟用卸載功能。

重新實現啟用

啟用重新實體化 (又稱梯度檢查點) 是減少模型 HBM 占用空間的基本技術。這項技術不會將正向傳遞中的所有中繼啟動值儲存在 HBM 中,以供反向傳遞使用,而是只儲存幾個重要啟動值 (檢查點),並在反向傳遞期間視需要重新計算其他啟動值。這麼做可節省大量記憶體,但會增加運算量 (標準 Transformer 區塊約增加 25% 至 30% 的 FLOP)。

如何積極套用重新實體化是重要的調整參數,完全取決於主要瓶頸,而這通常會因序列長度而異。

長序列工作負載 (例如 128k):在這些情況下,啟動張量的大小是 HBM 的主要消耗者。工作負載通常會受到記憶體限制。因此,套用積極的重新實體化政策非常有益。節省記憶體可讓訓練繼續進行,不會發生記憶體不足錯誤,還能使用較大的批次大小,重新計算的運算負荷是值得的取捨。

短序列工作負載 (例如 8k):在這些情況下,啟動記憶體較不令人擔心,工作負載也較有可能受到運算限制。重新實現的運算負擔可能是效率不彰的最大來源。

在 MaxText 中調整重新實體化政策

MaxText 提供一組預設和自訂政策,可透過 remat_policy 旗標設定,精細控制重新實現作業。

預設政策

MaxText 提供下列內建政策:

  • full:最積極的政策,幾乎所有內容都會重新實體化。這會盡量減少 HBM 用量,但會盡量增加重新運算的負擔。非常適合記憶體極度受限的長序列情境。
  • minimal:最不積極的政策,會儲存最多啟用次數。這項做法可充分運用 HBM,但會盡量減少重新運算。最適合短序列、受運算限制的工作負載,且不需擔心記憶體問題。
  • 中繼政策save_dot_with_context_except_mlpsave_qkv_projsave_out_proj 等選項會選擇性檢查昂貴點積運算的輸出內容,同時重新實現較便宜的元素運算,藉此提供各種取捨方案。

自訂政策

如要進一步控管,可以將 remat_policy 設為 custom。這可讓您指定模型解碼模組中個別層的行為。每個圖層都可以指派下列三種行為之一:

  • device:啟用狀態會儲存在 TPU 裝置的 HBM 中。
  • remat:系統會捨棄啟用值,並在反向傳遞期間重新實現。
  • offload:啟用作業會從 HBM 移至 CPU 主機的記憶體,以 PCIe 傳輸延遲為代價,釋放 HBM。

範圍 VMEM 調整

核心效能 (例如快閃注意力機制) 取決於核心中選取的圖塊大小,而圖塊大小會受到可用向量記憶體 (VMEM) 的限制。TPU7x 晶片中的每個 TensorCore 都有 64 MiB 的向量記憶體 (VMEM)。這個 VMEM 容量可分配給目前範圍 (範圍 VMEM) 和未來權重預先擷取。增加範圍 VMEM 可在核心中增加圖塊大小,進而減少記憶體停滯,並提升核心效能。您可以設定 xla_tpu_scoped_vmem_limit_kib (位於 LIBTPU_INIT_ARGS 中),藉此變更範圍內的 VMEM 大小,用於探索核心效能和端對端效能限制。最佳化範圍 VMEM 大小可間接影響自訂 Pallas 核心效能,因為增加範圍 VMEM 會為核心內圖塊大小解鎖更大的超參數搜尋空間。

Tokamax 核心

Tokamax 是高效能的 JAX 核心程式庫,內含許多經過高度最佳化的 TPU 核心,可解決多個常見的硬體專屬效能瓶頸:

  • Splash 注意力:Splash 注意力是主要注意力實作方式,可消除標準注意力的 HBM 瓶頸,並在 TPU 上使用最有效率的注意力實作方式。
  • Megablox 分組矩陣乘法 (GMM):對於 MoE 工作負載,Megablox 會透過計算不規則的啟動表示法,有效處理分組矩陣乘法。這項技術可有效對應參差不齊的維度,計算 LHS 中參差不齊的資料列群組與對應專家矩陣之間的矩陣乘法,不必將批次填補至固定大小。
  • 使用 tune-jax 進行實證調整tune-jax 程式庫提供公用程式,可執行實證搜尋,找出最佳區塊大小。預設核心大小通常不是最佳選擇;調整大小可選擇適合硬體的 VMEM 圖塊大小,盡量提高硬體使用率。
  • 最大 logits 估計值:您可以為 max_logit_const 設定值,進一步最佳化 Tokamax Splash 注意力核心。如果設定此值,系統會在注意力機制 (softmax(Q * KT)) 的 softmax 運算期間,取代最大邏輯值的縮減計算,減少部分運算和同步處理的負擔。在 MaxText 中,這項功能是由 config use_max_logits_estimate 實作,可設為 None (停用) 或浮點值。確認特定模型的 Logit 範圍與估算值相容,以免發生數值溢位。如果設定這個值,建議進行收斂測試。