TPU7x(Ironwood)のパフォーマンスの最適化
このガイドでは、TPU7x(Ironwood)でパフォーマンスを最適化するためのいくつかの方法について説明します。これらの方法では、多層メモリ システム間のデータ移動を効率的に管理します。これには、低精度トレーニング、シャーディング、通信の最適化、アクティベーションの再マテリアライズ、スコープ付き仮想メモリ チューニング、カスタム アクセラレータ カーネルなどの手法が含まれます。
TPU7x でパフォーマンスを最適化するには、まず Ironwood アーキテクチャ、特にメモリ階層とインターコネクト トポロジを理解する必要があります。詳細については、TPU7x(Ironwood)をご覧ください。
FP8 を使用した低精度トレーニング
FP8(8 ビット浮動小数点)は、主にモデルのトレーニングと推論を高速化するために使用される効率的な数値データ形式です。TPU は、標準の 16 ビット形式(FP16 または BF16)と 32 ビット(FP32)ではなく、8 ビットを使用して数値を表現することで、データを大幅に高速に処理し、メモリ使用量を削減できます。
TPU7x は、FP8 データ型の組み込みハードウェア アクセラレーションをサポートしており、チップあたり 4, 614 TFLOPS の理論上のピーク パフォーマンスを実現します。この機能により、エンドツーエンドのトレーニング時間を大幅に短縮できます。互換性のあるオペレーション、特に AI ワークロードで一般的な高密度行列乗算では、FP8 を使用すると、標準の BF16 トレーニングよりも 1.3 倍のパフォーマンス向上が得られます。BF16 と比較して、FP8 はピーク FLOP を 2 倍にし、重みとアクティベーションのメモリ使用量を半分にします。FP8 は、コンピューティング バウンドのワークロードと、メモリ容量または帯域幅によって制約されるシナリオの両方で、主要なチューニング レバーとなります。
FP8 を使用すると、次のようなパフォーマンス上のメリットがあります。
- 高帯域幅メモリ(HBM)のプレッシャーの軽減: メモリ使用量が小さいため、推論時に大きなモデルや大きな KV キャッシュを持つモデルを 192 GB の HBM に完全に収めることができます。これにより、低速のホストメモリへの高コストなオフロードを回避できます。
- 有効なバッチサイズの増加: アクティベーションに必要なメモリが削減されるため、FP8 ではバッチサイズを大きくできます。これにより、データ並列処理が改善され、スループットの向上とコンピューティング ユニットの使用率の向上につながります。
- メモリ帯域幅の要件の低減: オペレーションごとに移動するデータ量を半分にすることで、HBM から MXU へのデータパスの負荷が軽減されます。データ移動が一般的なボトルネックとなるシステムでは、MXU の作業負荷を飽和状態に保つことができます。
FP8 を使用してパフォーマンスの低下をゼロまたは最小限に抑えるには、量子化手法を慎重に選択する必要があります。FP8 トレーニングで考慮すべきベスト プラクティスをいくつかご紹介します。
- スケーリングの粒度: ベースラインとしてテンソルごとのスケーリングから始めます。品質やパフォーマンスに問題がある場合は、軸ごとのスケーリングに切り替えます。 サブチャネルのスケーリングは不要な場合があります。
- スケーリング モード: ランタイムでスケーリング ファクタを計算する動的スケーリングは、品質を維持するための適切なデフォルトです。静的スケーリングでは、計算を排除することでパフォーマンスを大幅に向上させることができますが、正しいスケーリング ファクタを決定するには慎重なプロファイリングが必要であり、特にモデル構成が変更される場合は、すべてのユースケースに適しているとは限りません。一方、堅牢なモデルと構成では、重みまたはアクティベーションの FP8 上限にスケールを固定できるため、精度を維持しながら量子化のオーバーヘッドを削減し、パフォーマンスを向上させることができます。
- FP8 形式(E4M3 と E5M2): 一般的で効果的なアプローチは、FP8 形式を 組み合わせて使用することです。たとえば、フォワード パスでは重みとアクティベーションに E4M3 を使用して E4M3 の高精度を活用し、バックワード パスでは勾配に E5M2 を使用して勾配の広いダイナミック レンジに対応します。
- 丸め: 勾配に確率的 丸めではなく「最近傍偶数への丸め」(RNE)を使用すると、品質を維持しながらパフォーマンスと再現性を向上させることができます。
- MaxText で FP8 を有効にする:
MaxText は、QWIX 量子化ライブラリを介して FP8 トレーニングをサポートしています。量子化を有効にするには、構成で
use_qwix_quantization=trueフラグを設定します。
シャーディングと並列処理
シャーディングとは、大規模なモデルまたはそのトレーニング データを小さなピースに分割し、複数の TPU チップまたはコアに分散するプロセスです。 TPU7x で高いパフォーマンスを実現するには、適切なシャーディング戦略を選択することが重要です。
並列処理の度合いを純粋に最大化するナイーブなアプローチでは、通信バウンドになることでパフォーマンスが低下することがよくあります。最適なアプローチは、メモリ制約を満たす最もシンプルなシャーディング戦略を選択することです。これにより、通信のオーバーヘッドが最小限に抑えられ、コンピューティング ユニットを効率的に使用できます。
シャーディング戦略を選択する前に、パフォーマンス チューニングの最初のステップとして、演算密度の分析を行う必要があります。この分析では、特定の計算がコンピューティング、メモリ帯域幅、インターコネクト帯域幅のいずれによって制限されているかを判断します。これは、浮動小数点演算の比率を移動する必要があるデータのバイト数で割った値として計算されます。
演算密度が高い場合は、コンピューティング バウンドのワークロードを示します。演算密度が低い場合は、メモリまたは通信バウンドのワークロードを示します。この場合、パフォーマンスは HBM から ICI ネットワークへのデータ移動速度によって制限されます。この分析により、最適なバッチサイズとシャーディング戦略を把握できます。たとえば、通信バウンドのワークロードでは、高度なテンソル並列処理など、通信をさらに増やすシャーディング戦略はメリットがありません。
シャーディング戦略の決定フレームワーク
MaxText には、さまざまなシャーディング戦略が用意されています。最適な選択は、モデル アーキテクチャ、シーケンス長、計算負荷と通信オーバーヘッドのバランスを取る必要性によって異なります。
- 完全にシャーディングされたデータ並列処理(FSDP): これは、データ並列処理の推奨されるデフォルト 戦略です。FSDP は、モデルの重み、勾配、オプティマイザの状態をデータ並列デバイス間でシャーディングします。計算中、各デバイスは All-Gather オペレーションを実行して、ローカル マイクロバッチに必要な完全な重みを取得します。デバイスごとのバッチサイズが、この All-Gather 通信のレイテンシを隠すのに十分な大きさであれば、FSDP は非常に効果的です。Mixture-of-Experts(MoE)モデルの場合、演算密度の計算ではスパース性を考慮する必要があります。
- テンソル並列処理(TP): TP は、個々のテンソルをデバイス間でシャーディングします。 通常、テンソルは多層パーセプトロン(MLP)とアテンション ブロックの重み行列です。ハードウェアの演算強度が高い(11.5k)ため、ICI で TP を実現するには、モデルのディメンションに非常に高い要件が課せられます。TP を使用しようとすると、システムが通信バウンドになる可能性があります。
- エキスパート並列処理(EP): これは、 MoE モデルをトレーニングするための標準的で必要な戦略です。EP は、一連のデバイス間で「エキスパート」レイヤをシャーディングし、All-to-All 通信コレクティブを使用してトークンを指定されたエキスパート デバイスにルーティングします。モデルの MLP ディメンションがルーフラインに近づくほど、EP は効率的になります。
- コンテキスト並列処理(CP): CP は、シーケンス長が非常に長いモデルをトレーニングするために不可欠な 特殊な戦略です。主な機能は、アクティベーションのメモリ使用量を管理することです。アクティベーションのメモリ使用量はシーケンス長とともに二次関数的に増加し、HBM 容量を超える可能性があります。CP は、アクティベーション テンソルのシーケンス ディメンションをシャーディングするため、デバイスごとのバッチサイズを分数で使用できます。CP は FSDP よりも多くの通信を導入するため、一般的なルールとして、メモリ制約を満たすために必要な最小限の CP を使用し、バッチ軸のシャードが整数であることを確認します。
次の表に、一般的なワークロード タイプと最適なシャーディング戦略のマッピングを示します。
| ワークロード タイプ | 推奨されるプライマリ シャーディング | セカンダリ シャーディング | 主なボトルネック | 根拠 |
|---|---|---|---|---|
| 高密度モデル - 短いシーケンス | FSDP | なし | 再マテリアライズ、FF Matmuls | FSDP は最適なバランスを提供します。シーケンスが短い場合、アクティベーション メモリは大きな問題ではない可能性があります。重要なのは、FSDP の重み All-Gather を隠すのに十分なグローバル バッチ です。バッチサイズが増加すると、アクティベーション サイズが増加します。この構成でメモリ不足にならないようにするには、適切な再マテリアライズ ポリシーが必要です 。 |
| 高密度モデル - 長いシーケンス | FSDP | CP | Flash アテンション、アクティベーション メモリ | アクティベーション メモリが主な制約になります。デバイスごとのバッチサイズを分数で有効にし、メモリ不足(OOM)の問題を回避するには、CP が必要です。Flash アテンションは、コンピューティングと無駄な時間の主な原因です。 |
| MoE モデル - 短いシーケンス | FSDP + EP | なし | All-to-All(エキスパート ルーティング)、再マテリアライズ | MoE モデルでは、エキスパートをシャーディングするために EP が必要です。トークン ルーティングの All-to-All 通信は、オーバーラップする必要がある 大きなボトルネックです。再マテリアライズも無駄の大きな原因です。 |
| MoE モデル - 非常に大規模 | FSDP + EP + PP | モデル並列処理(MP) | 前述のボトルネックに加えて、パイプライン バブル | 単一 Pod のメモリを超えるモデルの場合、Pod 間でレイヤをシャーディングするには PP が必要です。これにより、DCN 通信とパイプライン バブルのオーバーヘッドが発生します。これは非常に複雑な構成であり、慎重な チューニングが必要です。 |
通信の最適化
TPU7x で通信とコンピューティングをオーバーラップする主なメカニズムは、SparseCore Collective Offloading と呼ばれます。 Ironwood アーキテクチャには、専用の SparseCore ユニットが含まれています。これは、ICI ファブリック上のデータ移動を管理できる独立した制御スレッドとして機能します。これにより、All-Gather や Reduce-Scatter などの集団通信オペレーションを、TensorCore で実行されるメインの計算と並行して実行できます。これは、TPU7x で非同期コレクティブを使用する場合に推奨される方法です。推奨されるフラグを使用して、最も一般的なコレクティブのオフロードを有効にします。
アクティベーションの再マテリアライズ
アクティベーションの再マテリアライズ(勾配チェックポイントとも呼ばれます)は、モデルの HBM フットプリントを削減するための基本的な手法です。バックワード パスで使用するためにフォワード パスからの中間アクティベーションをすべて HBM に保存するのではなく、いくつかのキー アクティベーション(チェックポイント)のみを保存し、バックワード パスで他のアクティベーションをオンデマンドで再計算します。これにより、計算量の増加(標準の Transformer ブロックの場合、約 25 ~ 30% の追加 FLOP)と引き換えに、メモリを大幅に節約できます。
再マテリアライズをどの程度積極的に適用するかは、プライマリ ボトルネックに完全に依存する重要なチューニング パラメータです。プライマリ ボトルネックはシーケンス長によって異なることがよくあります。
長いシーケンスのワークロード (128k など)の場合: この場合、アクティベーション テンソルのサイズが HBM の主なコンシューマーになります。通常、ワークロードはメモリバウンドです。したがって、積極的な再マテリアライズ ポリシーを適用することは非常に有益です。メモリを節約することで、メモリ不足エラーが発生することなくトレーニングを進めることができ、バッチサイズを大きくすることもできます。再計算の計算オーバーヘッドは、価値のあるトレードオフです。
短いシーケンスのワークロード (8k など)の場合: この場合、アクティベーション メモリはそれほど問題ではなく、ワークロードがコンピューティング バウンドになる可能性が高くなります。 再マテリアライズの計算オーバーヘッドは、非効率性の最大の原因となる可能性があります。
MaxText で再マテリアライズ ポリシーをチューニングする
MaxText では、remat_policy フラグを使用して構成された、一連のプリセット ポリシーとカスタム ポリシーを使用して、再マテリアライズをきめ細かく制御できます。
プリセット ポリシー
MaxText には、次の組み込みポリシーが用意されています。
full: ほぼすべてのものを再マテリアライズする、最も積極的なポリシー。 これにより、HBM の使用量が最小限に抑えられますが、再計算のオーバーヘッドが最大になります。メモリが非常に制約された長いシーケンスのシナリオに最適です。minimal: 最も積極性の低いポリシーで、ほとんどのアクティベーションを保存します。これにより、HBM の使用量が最大になりますが、再計算が最小限に抑えられます。メモリが問題にならない短いシーケンスのコンピューティング バウンドのワークロードに最適です。- 中間ポリシー:
save_dot_with_context_except_mlp、save_qkv_proj、save_out_projなどのオプションを使用すると、コストのかかるドット積オペレーションの出力を選択的にチェックポイントし、安価な要素ごとのオペレーションを再マテリアライズすることで、さまざまなトレードオフを実現できます。
カスタム ポリシー
より詳細な制御を行うには、remat_policy を custom に設定します。これにより、モデルのデコード モジュール内の個々のレイヤの動作を指定できます。各レイヤには、次の 3 つの動作のいずれかを割り当てることができます。
device: アクティベーションは TPU デバイスの HBM に保存されます。remat: アクティベーションは破棄され、バックワード パスで再マテリアライズされます。offload: アクティベーションは HBM から CPU ホストのメモリに移動され、PCIe 転送レイテンシと引き換えに HBM が解放されます。
スコープ付き VMEM チューニング
Flash アテンションなどのカーネル パフォーマンスは、カーネルで選択されたタイルサイズによって異なります。タイルサイズは、使用可能なベクトルメモリ(VMEM)によって制限されます。TPU7x チップの 2 つの TensorCore には、それぞれ 64 MiB のベクトルメモリ(VMEM)があります。この VMEM 容量は、現在のスコープ(スコープ付き VMEM)と将来の重みのプリフェッチの間で分割できます。スコープ付き VMEM を増やすと、カーネルのタイルサイズを大きくできるため、メモリストールが減少し、カーネルのパフォーマンスが向上する可能性があります。xla_tpu_scoped_vmem_limit_kib(LIBTPU_INIT_ARGS 内)を設定して、スコープ付き VMEM サイズを変更できます。これを使用して、カーネル パフォーマンスとエンドツーエンドのパフォーマンスの上限を調べることができます。スコープ付き VMEM サイズを最適化すると、カスタム Pallas カーネルのパフォーマンスに間接的に影響する可能性があります。これは、スコープ付き VMEM を増やすと、カーネル内タイルサイズのハイパーパラメータ検索空間が広がるためです。
Tokamax カーネル
Tokamax は、多くの高度に最適化された TPU カーネルを含む高パフォーマンスの JAX カーネル ライブラリで、ハードウェア固有の一般的な ボトルネックをいくつか解決します。
- Splash アテンション: Splash アテンションは、標準のアテンションの HBM ボトルネックを解消するためのプライマリ アテンション 実装として使用され、TPU で最も効率的なアテンション実装を使用します。
- Megablox グループ化行列乗算(GMM): MoE ワークロードの場合、 Megablox は、不規則なアクティベーション表現で計算することで、 グループ化された行列乗算を効率的に処理します。不規則なディメンションに効率的にマッピングし、LHS の不規則な行グループと対応するエキスパート行列の間で行列乗算を計算します。バッチを固定サイズにパディングする必要はありません。
tune-jaxを使用した経験的チューニング:tune-jaxライブラリには、最適なブロックサイズを経験的に検索するユーティリティがあります。デフォルトのカーネルサイズは最適でないことがよくあります。チューニングにより、ハードウェアに適した VMEM タイルサイズを選択して、ハードウェアの使用率を最大化できます。- 最大ロジットの推定: Tokamax Splash アテンション カーネルは、
max_logit_constの値を設定することでさらに最適化できます。設定すると、アテンションのソフトマックス オペレーション(softmax(Q * KT))中に最大ロジットの削減計算が置き換えられ、計算と同期のオーバーヘッドが削減されます。MaxText では、構成use_max_logits_estimateによって実装されます。これは、None(無効)または浮動小数点値に設定できます。数値オーバーフローを防ぐため、特定のモデルのロジット範囲が推定値と互換性があることを確認してください。 この値を設定する場合は、収束テストをおすすめします。