Pengoptimalan performa TPU7x (Ironwood)
Panduan ini menjelaskan beberapa metode untuk mengoptimalkan performa dengan TPU7x (Ironwood) dengan mengelola pergerakan data secara efisien antara sistem memori bertingkatnya. Hal ini mencakup teknik seperti pelatihan presisi rendah, sharding, pengoptimalan komunikasi, rematerialisasi aktivasi, penyetelan VMEM cakupan, dan kernel akselerator kustom.
Untuk mengoptimalkan performa dengan TPU7x, Anda harus terlebih dahulu memahami arsitektur Ironwood, khususnya hierarki memori dan topologi interkoneksi. Untuk mengetahui informasi selengkapnya, lihat TPU7x (Ironwood).
Pelatihan presisi rendah dengan FP8
FP8 (floating point 8-bit) adalah format data numerik yang efisien dan digunakan terutama untuk mempercepat pelatihan dan inferensi model. Dengan merepresentasikan angka menggunakan 8 bit – bukan format 16-bit standar (FP16 atau BF16) dan 32-bit (FP32) – TPU dapat memproses data secara signifikan lebih cepat dan menggunakan lebih sedikit memori.
TPU7x mendukung akselerasi hardware bawaan untuk jenis data FP8, yang menawarkan performa teoretis puncak sebesar 4614 TFLOPS per chip. Kemampuan ini dapat menghasilkan waktu pelatihan end-to-end yang jauh lebih cepat. Untuk operasi yang kompatibel, terutama perkalian matriks padat yang umum untuk workload AI, penggunaan FP8 dapat menghasilkan peningkatan performa sebesar 1,3x dibandingkan pelatihan BF16 standar. Dibandingkan dengan BF16, FP8 menggandakan FLOP puncak dan mengurangi setengah footprint memori untuk bobot dan aktivasi. FP8 harus menjadi tuas penyetelan utama untuk workload yang terikat komputasi dan skenario yang dibatasi oleh kapasitas atau bandwidth memori.
Penggunaan FP8 menawarkan manfaat performa berikut:
- Mengurangi tekanan memori bandwidth tinggi (HBM): Footprint memori yang lebih kecil memungkinkan model yang lebih besar, atau model dengan cache KV yang lebih besar selama inferensi, agar sepenuhnya sesuai dengan HBM 192 GB. Hal ini menghindari offload yang mahal ke memori host yang lebih lambat.
- Meningkatkan ukuran batch efektif: Dengan mengurangi memori yang diperlukan untuk aktivasi, FP8 memungkinkan penggunaan ukuran batch yang lebih besar. Hal ini meningkatkan paralelisme data dan dapat menghasilkan throughput yang lebih tinggi serta pemanfaatan unit komputasi yang lebih baik.
- Persyaratan bandwidth memori yang lebih rendah: Memindahkan setengah jumlah data untuk setiap operasi mengurangi permintaan pada jalur data HBM ke MXU. Pada sistem yang pergerakan datanya merupakan bottleneck umum, hal ini membantu menjaga MXU tetap jenuh dengan pekerjaan.
Penggunaan FP8 dengan degradasi performa nol atau terbatas memerlukan pemilihan teknik kuantisasi yang cermat. Berikut beberapa praktik terbaik yang perlu dipertimbangkan untuk pelatihan FP8:
- Granularitas penskalaan: Mulai dengan penskalaan per tensor sebagai dasar. Jika ada masalah kualitas atau performa, beralihlah ke penskalaan per sumbu. Penskalaan sub-channel mungkin tidak diperlukan.
- Mode penskalaan: Penskalaan dinamis, yang menghitung faktor penskalaan saat runtime, adalah default yang baik untuk mempertahankan kualitas. Meskipun penskalaan statis dapat menawarkan peningkatan performa yang signifikan dengan menghilangkan komputasi, penskalaan statis memerlukan pembuatan profil yang cermat untuk menentukan faktor penskalaan yang benar dan mungkin tidak cocok untuk semua kasus penggunaan, terutama saat konfigurasi model berubah. Sebaliknya, beberapa model dan konfigurasi yang kuat dapat memperbaiki skala ke batas FP8 untuk bobot atau aktivasi, sehingga Anda dapat mengurangi overhead kuantisasi sekaligus mempertahankan akurasi dan meningkatkan performa.
- Format FP8 (E4M3 dan E5M2): Pendekatan umum dan efektif adalah menggunakan campuran format FP8. Misalnya, gunakan E4M3 untuk bobot dan aktivasi dalam forward pass untuk memanfaatkan presisi E4M3 yang lebih tinggi, dan gunakan E5M2 untuk gradien dalam backward pass untuk mengakomodasi rentang dinamis gradien yang lebih luas.
- Pembulatan: Menggunakan "round to nearest even" (RNE) dan bukan pembulatan stokastik untuk gradien dapat mempertahankan kualitas sekaligus menawarkan performa dan kemampuan reproduksi yang lebih baik.
- Mengaktifkan FP8 di MaxText:
MaxText mendukung pelatihan FP8
melalui library kuantisasi QWIX. Untuk mengaktifkan kuantisasi, tetapkan flag berikut dalam konfigurasi Anda:
use_qwix_quantization=true.
Sharding dan paralelisme
Sharding adalah proses membagi model besar atau data pelatihannya menjadi bagian yang lebih kecil dan mendistribusikannya ke beberapa chip atau core TPU. Memilih strategi sharding yang tepat penting untuk mencapai performa tinggi di TPU7x.
Pendekatan naif yang memaksimalkan tingkat paralelisme secara murni sering kali menghasilkan performa yang buruk karena terikat komunikasi. Pendekatan terbaik sering kali adalah memilih strategi sharding paling sederhana yang memenuhi batasan memori, karena hal ini meminimalkan overhead komunikasi dan memungkinkan unit komputasi digunakan secara efisien.
Sebelum memilih strategi sharding, langkah pertama dalam upaya penyetelan performa apa pun harus berupa analisis intensitas aritmetika. Analisis ini menentukan apakah komputasi tertentu dibatasi oleh komputasi, bandwidth memori, atau bandwidth interkoneksi. Analisis ini dihitung sebagai rasio operasi floating point terhadap byte data yang harus dipindahkan.
Intensitas aritmetika yang tinggi menunjukkan workload yang terikat komputasi. Intensitas aritmetika yang rendah menunjukkan workload yang terikat memori atau komunikasi, dengan performa yang dibatasi oleh kecepatan data dapat dipindahkan dari HBM atau di seluruh jaringan ICI. Analisis ini memberi tahu ukuran batch dan strategi sharding yang ideal. Misalnya, workload yang terikat komunikasi tidak akan mendapatkan manfaat dari strategi sharding yang memperkenalkan lebih banyak komunikasi, seperti paralelisme tensor tingkat tinggi.
Kerangka keputusan strategi sharding
MaxText menawarkan berbagai strategi sharding. Pilihan optimal bergantung pada arsitektur model, panjang urutan, dan kebutuhan untuk menyeimbangkan beban komputasi dengan overhead komunikasi.
- Fully Sharded Data Parallelism (FSDP): Ini adalah strategi default pilihan untuk paralelisme data. FSDP membagi bobot model, gradien, dan status pengoptimal di seluruh perangkat paralel data. Selama komputasi, setiap perangkat melakukan operasi All-Gather untuk mengambil bobot penuh yang diperlukan untuk microbatch lokalnya. FSDP sangat efektif selama ukuran batch per perangkat cukup besar untuk menyembunyikan latensi komunikasi All-Gather ini. Untuk model Mixture-of-Experts (MoE), perhitungan intensitas aritmetika harus memperhitungkan sparsitas.
- Tensor Parallelism (TP): TP membagi tensor individual di seluruh perangkat. Biasanya, tensor adalah matriks bobot dalam multilayer perceptron (MLP) dan blok perhatian. Intensitas aritmetika hardware yang tinggi (11,5k) memberikan persyaratan yang sangat tinggi pada dimensi model untuk membuat TP layak digunakan melalui ICI, dan mencoba menggunakan TP dapat menyebabkan sistem terikat komunikasi.
- Expert Parallelism (EP): Ini adalah strategi standar dan diperlukan untuk melatih model MoE. EP membagi lapisan "pakar" di seluruh kumpulan perangkat, dan kolektif komunikasi All-to-All digunakan untuk merutekan token ke perangkat pakar yang ditentukan. EP dapat efisien jika dimensi MLP model cukup besar untuk mendekati roofline.
- Context Parallelism (CP): CP adalah strategi khusus yang penting untuk melatih model dengan panjang urutan yang sangat panjang. Fungsi utamanya adalah mengelola konsumsi memori aktivasi, yang tumbuh secara kuadrat dengan panjang urutan dan dapat melebihi kapasitas HBM. CP membagi dimensi urutan tensor aktivasi, yang memungkinkan penggunaan ukuran batch per perangkat pecahan. Karena CP memperkenalkan lebih banyak komunikasi daripada FSDP, aturan umumnya adalah menggunakan tingkat CP minimum yang diperlukan untuk memenuhi batasan memori dan memastikan shard sumbu batch tetap berupa bilangan bulat.
Tabel berikut memetakan jenis workload umum ke strategi sharding yang optimal:
| Jenis workload | Sharding utama yang direkomendasikan | Sharding sekunder | Bottleneck utama | Alasan |
|---|---|---|---|---|
| Model padat - urutan pendek | FSDP | T/A | Rematerialisasi, FF Matmuls | FSDP memberikan keseimbangan terbaik. Dengan urutan pendek, memori aktivasi mungkin bukan masalah utama. Kuncinya adalah batch global yang cukup besar untuk menyembunyikan All-Gather bobot FSDP. Saat ukuran batch meningkat, ukuran aktivasi meningkat, dan kebijakan rematerialisasi yang sesuai diperlukan untuk memastikan konfigurasi ini tidak kehabisan memori. |
| Model padat - urutan panjang | FSDP | CP | Perhatian flash, memori aktivasi | Memori aktivasi menjadi batasan utama. CP diperlukan untuk mengaktifkan ukuran batch per perangkat pecahan dan menghindari masalah kehabisan memori (OOM) . Perhatian flash adalah sumber komputasi dan waktu yang terbuang yang dominan. |
| Model MoE - urutan pendek | FSDP + EP | T/A | All-to-All (Perutean pakar), rematerialisasi | Model MoE memerlukan EP untuk membagi pakar. Komunikasi All-to-All komunikasi untuk perutean token adalah bottleneck utama yang harus tumpang-tindih. Rematerialisasi juga merupakan sumber pemborosan yang signifikan. |
| Model MoE - skala sangat besar | FSDP + EP + PP | Paralelisme model (MP) | Semua bottleneck yang disebutkan sebelumnya, ditambah bubble pipeline | Untuk model yang melebihi memori satu pod, PP diperlukan untuk membagi lapisan di seluruh pod. Hal ini memperkenalkan overhead komunikasi DCN dan pipeline bubble. Ini adalah konfigurasi yang sangat kompleks dan memerlukan penyetelan yang cermat. |
Pengoptimalan komunikasi
Mekanisme utama untuk komunikasi dan komputasi yang tumpang-tindih di TPU7x disebut SparseCore Collective Offloading. Arsitektur Ironwood mencakup unit SparseCore khusus, yang bertindak sebagai thread kontrol independen yang mampu mengelola pergerakan data melalui fabric ICI. Hal ini memungkinkan operasi komunikasi kolektif (seperti All-Gather atau Reduce-Scatter) dieksekusi secara paralel dengan komputasi utama yang terjadi di TensorCore. Ini adalah metode yang direkomendasikan untuk kolektif asinkron di TPU7x. Gunakan flag yang direkomendasikan untuk mengaktifkan offload bagi kolektif yang paling umum.
Rematerialisasi aktivasi
Rematerialisasi aktivasi, yang juga dikenal sebagai checkpointing gradien, adalah teknik mendasar untuk mengurangi footprint HBM model. Daripada menyimpan semua aktivasi perantara dari forward pass di HBM untuk digunakan selama backward pass, teknik ini hanya menyimpan beberapa aktivasi utama (checkpoint) dan menghitung ulang aktivasi lainnya sesuai permintaan selama backward pass. Hal ini menghemat sejumlah besar memori dengan mengorbankan peningkatan komputasi (sekitar 25-30% FLOP tambahan untuk blok transformer standar).
Keputusan tentang seberapa agresif penerapan rematerialisasi adalah parameter penyetelan penting yang sepenuhnya bergantung pada bottleneck utama, yang sering kali bervariasi dengan panjang urutan.
Untuk workload urutan panjang (seperti 128k): Dalam kasus ini, ukuran tensor aktivasi adalah konsumen HBM yang dominan. Workload biasanya terikat memori. Oleh karena itu, menerapkan kebijakan rematerialisasi yang agresif sangat bermanfaat. Penghematan memori memungkinkan pelatihan dilanjutkan tanpa error kehabisan memori dan juga memungkinkan ukuran batch yang lebih besar, dan overhead komputasi penghitungan ulang adalah trade-off yang bermanfaat.
Untuk workload urutan pendek (seperti 8k): Dalam kasus ini, memori aktivasi jauh lebih tidak menjadi masalah, dan workload lebih cenderung terikat komputasi. Overhead komputasi rematerialisasi dapat menjadi satu-satunya sumber inefisiensi terbesar.
Menyetel kebijakan rematerialisasi di MaxText
MaxText memberikan kontrol terperinci atas rematerialisasi melalui serangkaian kebijakan preset dan kustom, yang dikonfigurasi menggunakan flag remat_policy.
Kebijakan preset
MaxText menawarkan kebijakan bawaan berikut:
full: Kebijakan paling agresif, yang melakukan rematerialisasi hampir semuanya. Kebijakan ini meminimalkan penggunaan HBM, tetapi memaksimalkan overhead penghitungan ulang. Ideal untuk skenario urutan panjang yang sangat dibatasi memori.minimal: Kebijakan paling tidak agresif, yang menyimpan sebagian besar aktivasi. Kebijakan ini memaksimalkan penggunaan HBM, tetapi meminimalkan penghitungan ulang. Paling cocok untuk workload urutan pendek yang terikat komputasi dan memori tidak menjadi masalah.- Kebijakan perantara: Opsi seperti
save_dot_with_context_except_mlp,save_qkv_proj, dansave_out_projmemberikan berbagai trade-off dengan melakukan checkpointing secara selektif pada output operasi produk titik yang mahal sekaligus melakukan rematerialisasi operasi per elemen yang lebih murah.
Kebijakan kustom
Untuk tingkat kontrol yang lebih tinggi, Anda dapat menetapkan remat_policy ke custom. Hal ini memungkinkan Anda menentukan perilaku untuk setiap lapisan dalam modul dekode model. Setiap lapisan dapat ditetapkan ke salah satu dari tiga perilaku:
device: Aktivasi disimpan di HBM pada perangkat TPU.remat: Aktivasi akan dihapus dan akan dirematerialisasi selama backward pass.offload: Aktivasi dipindahkan dari HBM ke memori host CPU, sehingga mengosongkan HBM dengan mengorbankan latensi transfer PCIe.
Penyetelan VMEM cakupan
Performa kernel, seperti perhatian flash, bergantung pada ukuran petak yang dipilih dalam kernel, yang ukurannya dibatasi oleh memori vektor (VMEM) yang tersedia. Setiap dari dua TensorCore dalam chip TPU7x memiliki memori vektor (VMEM) sebesar 64 MiB. Kapasitas VMEM ini dapat dibagi antara cakupan saat ini (VMEM cakupan) dan pengambilan data awal bobot di masa mendatang. Meningkatkan VMEM cakupan memungkinkan peningkatan ukuran petak dalam kernel, yang berpotensi mengurangi penundaan memori dan meningkatkan performa kernel. Anda dapat mengubah ukuran VMEM cakupan dengan menetapkan xla_tpu_scoped_vmem_limit_kib (di LIBTPU_INIT_ARGS), yang dapat digunakan untuk menjelajahi performa kernel serta batas performa end-to-end. Mengoptimalkan ukuran VMEM cakupan dapat secara tidak langsung memengaruhi performa kernel Pallas kustom karena meningkatkan VMEM cakupan akan membuka ruang penelusuran hiperparameter yang lebih besar untuk ukuran petak dalam kernel.
Kernel Tokamax
Tokamax, library kernel JAX berperforma tinggi dengan banyak kernel TPU yang sangat dioptimalkan, mengatasi beberapa bottleneck khusus hardware umum:
- Perhatian splash: Perhatian splash digunakan sebagai penerapan perhatian utama untuk menghilangkan bottleneck HBM perhatian standar dan menggunakan penerapan perhatian yang paling efisien di TPU.
- Perkalian Matriks Dikelompokkan Megablox (GMM): Untuk workload MoE, Megablox menangani perkalian matriks yang dikelompokkan secara efisien dengan melakukan komputasi atas representasi aktivasi yang tidak rata. Representasi ini memetakan secara efisien dimensi yang tidak rata, menghitung perkalian matriks antara kelompok baris yang tidak rata di LHS, dan matriks pakar yang sesuai, sehingga menghindari kebutuhan untuk menambahkan padding pada batch ke ukuran tetap.
- Penyetelan empiris dengan
tune-jax: Librarytune-jaxmemiliki utilitas untuk melakukan penelusuran empiris untuk ukuran blok yang optimal. Ukuran kernel default sering kali tidak optimal; penyetelan memungkinkan pemilihan ukuran petak VMEM yang ramah hardware untuk memaksimalkan penggunaan hardware. - Perkiraan logit maks: Kernel perhatian Splash Tokamax dapat dioptimalkan lebih lanjut dengan
menetapkan nilai untuk
max_logit_const. Jika ditetapkan, nilai ini akan menggantikan perhitungan pengurangan logit maks selama operasi softmax perhatian (softmax(Q * KT)), sehingga mengurangi beberapa overhead komputasi dan sinkronisasi. Di MaxText, hal ini diterapkan oleh konfigurasiuse_max_logits_estimate, yang dapat ditetapkan keNone(dinonaktifkan) atau nilai floating point. Pastikan rentang logit model spesifik Anda tetap kompatibel dengan perkiraan untuk mencegah overflow numerik. Pengujian konvergensi direkomendasikan jika nilai ini ditetapkan.