Optimizaciones de rendimiento de TPU7x (Ironwood)
En esta guía, se describen varios métodos para optimizar el rendimiento con TPU7x (Ironwood) mediante la administración eficiente del movimiento de datos entre su sistema de memoria de varios niveles. Esto incluye técnicas como el entrenamiento de baja precisión, la fragmentación, la optimización de la comunicación, la rematerialización de la activación, el ajuste de la memoria virtual con alcance y los kernels de acelerador personalizados.
Para optimizar el rendimiento con TPU7x, primero debes familiarizarte con la arquitectura de Ironwood, específicamente la jerarquía de memoria y la topología de interconexión. Para obtener más información, consulta TPU7x (Ironwood).
Entrenamiento de baja precisión con FP8
FP8 (punto flotante de 8 bits) es un formato de datos numéricos eficiente que se usa principalmente para acelerar el entrenamiento y la inferencia de modelos. Al representar números con 8 bits, en lugar de los formatos estándar de 16 bits (FP16 o BF16) y 32 bits (FP32), las TPU pueden procesar datos mucho más rápido y usar menos memoria.
TPU7x admite la aceleración de hardware integrada para los tipos de datos FP8, lo que ofrece un rendimiento teórico máximo de 4,614 TFLOPS por chip. Esta capacidad puede generar tiempos de entrenamiento de extremo a extremo mucho más rápidos. Para las operaciones compatibles, en particular, las multiplicaciones de matrices densas que son comunes para las cargas de trabajo de IA, el uso de FP8 puede generar mejoras de rendimiento de 1.3 veces en comparación con el entrenamiento BF16 estándar. En comparación con BF16, FP8 duplica los FLOP máximos y reduce a la mitad el espacio en memoria para los pesos y las activaciones. FP8 debe ser una palanca de ajuste principal para las cargas de trabajo vinculadas a la capacidad de procesamiento y las situaciones que están limitadas por la capacidad o el ancho de banda de la memoria.
El uso de FP8 ofrece los siguientes beneficios de rendimiento:
- Reducción de la presión de la memoria de gran ancho de banda (HBM): Un espacio en memoria más pequeño permite que los modelos más grandes o los modelos con cachés de KV más grandes durante la inferencia quepan por completo dentro de los 192 GB de HBM. Esto evita la costosa descarga a una memoria de host más lenta.
- Aumento del tamaño de lote efectivo: Al reducir la memoria requerida para las activaciones, FP8 permite el uso de tamaños de lote más grandes. Esto mejora el paralelismo de datos y puede generar una mayor capacidad de procesamiento y un mejor uso de las unidades de procesamiento.
- Menores requisitos de ancho de banda de memoria: Mover la mitad de la cantidad de datos para cada operación reduce la demanda en la ruta de datos de HBM a MXU. En los sistemas en los que el movimiento de datos es un cuello de botella común, esto ayuda a mantener las MXU saturadas de trabajo.
El uso de FP8 con una degradación nula o limitada en el rendimiento requiere la selección cuidadosa de las técnicas de cuantización. Estas son algunas prácticas recomendadas que debes tener en cuenta para el entrenamiento de FP8:
- Escalamiento de granularidad: Comienza con el escalamiento por tensor como línea de base. Si hay problemas de calidad o rendimiento, cambia al escalamiento por eje. Es posible que el escalamiento de subcanales no sea necesario.
- Modo de escalamiento: El escalamiento dinámico, que calcula los factores de escalamiento en el tiempo de ejecución, es un buen valor predeterminado para mantener la calidad. Si bien el escalamiento estático puede ofrecer un aumento significativo del rendimiento mediante la eliminación de cálculos, requiere una creación de perfiles cuidadosa para determinar los factores de escalamiento correctos y es posible que no sea adecuado para todos los casos de uso, en especial cuando cambian las configuraciones del modelo. Por el contrario, algunos modelos y configuraciones sólidos pueden fijar la escala al límite de FP8 para los pesos o las activaciones, lo que te permite reducir la sobrecarga de cuantización y, al mismo tiempo, mantener la precisión y mejorar el rendimiento.
- Formatos FP8 (E4M3 y E5M2): Un enfoque común y eficaz es usar una combinación de formatos FP8. Por ejemplo, usa E4M3 para los pesos y las activaciones en la propagación hacia adelante para aprovechar la mayor precisión de E4M3 y usa E5M2 para los gradientes en la retropropagación para adaptarse al rango dinámico más amplio de los gradientes.
- Redondeo: El uso de "redondear al par más cercano" (RNE) en lugar del redondeo estocástico para los gradientes puede mantener la calidad y, al mismo tiempo, ofrecer un mejor rendimiento y reproducibilidad.
- Habilitación de FP8 en MaxText:
MaxText admite el entrenamiento de FP8
a través de la biblioteca de cuantización QWIX. Para activar la cuantización, establece la siguiente marca en tu configuración:
use_qwix_quantization=true.
Fragmentación y paralelismo
La fragmentación es el proceso de dividir un modelo grande o sus datos de entrenamiento en partes más pequeñas y distribuirlas en varios chips o núcleos de TPU. Elegir la estrategia de fragmentación adecuada es importante para lograr un alto rendimiento en TPU7x.
Un enfoque ingenuo que maximiza puramente el grado de paralelismo a menudo genera un rendimiento deficiente al volverse dependiente de la comunicación. El mejor enfoque suele ser seleccionar la estrategia de fragmentación más simple que cumpla con las restricciones de memoria, ya que esto minimiza la sobrecarga de comunicación y permite que las unidades de procesamiento se utilicen de manera eficiente.
Antes de seleccionar una estrategia de fragmentación, el primer paso en cualquier esfuerzo de ajuste de rendimiento debe ser un análisis de intensidad aritmética. Este análisis determina si un cálculo determinado está limitado por la capacidad de procesamiento, el ancho de banda de la memoria o el ancho de banda de interconexión. Se calcula como la proporción de operaciones de punto flotante con respecto a los bytes de datos que se deben mover.
Una alta intensidad aritmética indica una carga de trabajo vinculada a la capacidad de procesamiento. Una baja intensidad aritmética sugiere una carga de trabajo vinculada a la memoria o la comunicación, en la que el rendimiento está limitado por la velocidad a la que se pueden mover los datos desde HBM o a través de la red ICI. Este análisis informa el tamaño de lote y la estrategia de fragmentación ideales. Por ejemplo, una carga de trabajo vinculada a la comunicación no se beneficiará de una estrategia de fragmentación que introduzca aún más comunicación, como el paralelismo de tensores de alto grado.
Marco de decisiones de la estrategia de fragmentación
MaxText ofrece una variedad de estrategias de fragmentación. La mejor opción depende de la arquitectura del modelo, la longitud de la secuencia y la necesidad de equilibrar la carga computacional con la sobrecarga de comunicación.
- Paralelismo de datos completamente fragmentados (FSDP): Esta es la estrategia predeterminada preferida para el paralelismo de datos. FSDP fragmenta los pesos, los gradientes y los estados del optimizador del modelo en los dispositivos paralelos de datos. Durante el cálculo, cada dispositivo realiza una operación All-Gather para recuperar los pesos completos necesarios para su microlote local. FSDP es muy eficaz siempre que el tamaño del lote por dispositivo sea lo suficientemente grande como para ocultar la latencia de esta comunicación All-Gather. Para los modelos de Mixture-of-Experts (MoE), el cálculo de intensidad aritmética debe tener en cuenta la dispersión.
- Paralelismo de tensores (TP): TP fragmenta tensores individuales en los dispositivos. Por lo general, los tensores son matrices de peso en perceptrones multicapa (MLP) y bloques de atención. La alta intensidad aritmética del hardware (11.5k) impone un requisito muy alto en las dimensiones del modelo para que TP sea viable en ICI, y el intento de usar TP puede hacer que el sistema esté vinculado a la comunicación.
- Paralelismo de expertos (EP): Esta es la estrategia estándar y necesaria para entrenar modelos MoE. EP fragmenta las capas "expertas" en un conjunto de dispositivos, y se usa un colectivo de comunicación All-to-All para enrutar tokens a su dispositivo experto designado. EP puede ser eficiente si la dimensión MLP del modelo es lo suficientemente grande como para acercarse a la línea de techo.
- Paralelismo de contexto (CP): CP es una estrategia especializada que es esencial para entrenar modelos con longitudes de secuencia muy largas. Su función principal es administrar el consumo de memoria de las activaciones, que crece de forma cuadrática con la longitud de la secuencia y puede exceder la capacidad de HBM. CP fragmenta la dimensión de secuencia de los tensores de activación, lo que permite el uso de un tamaño de lote fraccional por dispositivo. Debido a que CP introduce más comunicación que FSDP, la regla general es usar el grado mínimo de CP necesario para satisfacer las restricciones de memoria y garantizar que el fragmento del eje de lote siga siendo un número entero.
En la siguiente tabla, se asignan tipos de cargas de trabajo comunes a la estrategia de fragmentación óptima:
| Tipo de carga de trabajo | Fragmentación primaria recomendada | Fragmentación secundaria | Cuellos de botella clave | Razones |
|---|---|---|---|---|
| Modelo denso: secuencia corta | FSDP | N/A | Rematerialización, FF Matmuls | FSDP proporciona el mejor equilibrio. Con secuencias cortas, es posible que la memoria de activación no sea una preocupación importante. La clave es un lote global lo suficientemente grande como para ocultar el All-Gather de peso de FSDP. A medida que aumenta el tamaño del lote, también lo hace el tamaño de la activación, y se requiere una política de rematerialización adecuada para garantizar que esta configuración no se quede sin memoria. |
| Modelo denso: secuencia larga | FSDP | CP | Atención flash, memoria de activación | La memoria de activación se convierte en la restricción principal. Se requiere CP para habilitar tamaños de lote fraccionales por dispositivo y evitar problemas de memoria insuficiente (OOM) issues. La atención flash es la fuente dominante de capacidad de procesamiento y tiempo perdido. |
| Modelo MoE: secuencia corta | FSDP + EP | N/A | All-to-All (enrutamiento de expertos), rematerialización | Los modelos MoE requieren EP para fragmentar a los expertos. La comunicación All-to-All para el enrutamiento de tokens es un cuello de botella importante que se debe superponer. La rematerialización también es una fuente importante de desperdicio. |
| Modelo MoE: escala muy grande | FSDP + EP + PP | Paralelismo de modelos (MP) | Todos los cuellos de botella mencionados anteriormente, además de las burbujas de canalización | Para los modelos que exceden la memoria de un solo pod, se necesita PP para fragmentar capas en los pods. Esto introduce la comunicación DCN y las sobrecargas de burbujas de canalización. Esta es una configuración muy compleja que requiere un ajuste cuidadoso |
Optimización de la comunicación
El mecanismo principal para superponer la comunicación y la capacidad de procesamiento en TPU7x se denomina descarga colectiva de SparseCore. La arquitectura de Ironwood incluye unidades SparseCore dedicadas, que actúan como subprocesos de control independientes capaces de administrar el movimiento de datos a través de la estructura ICI. Esto permite que las operaciones de comunicación colectiva (como All-Gather o Reduce-Scatter) se ejecuten en paralelo con los cálculos principales que se realizan en los TensorCores. Este es el método recomendado para los colectivos asíncronos en TPU7x. Usa las marcas recomendadas para habilitar la descarga de los colectivos más comunes.
Rematerialización de la activación
La rematerialización de la activación, también conocida como punto de control de gradiente, es una técnica fundamental para reducir el espacio de HBM de un modelo. En lugar de almacenar todas las activaciones intermedias de la propagación hacia adelante en HBM para usarlas durante la retropropagación, solo guarda algunas activaciones clave (puntos de control) y vuelve a calcular las demás a pedido durante la retropropagación. Esto ahorra una cantidad significativa de memoria a costa de un aumento en la capacidad de procesamiento (aproximadamente un 25 a 30% de FLOP adicionales para un bloque de transformador estándar).
La decisión de qué tan agresivamente aplicar la rematerialización es un parámetro de ajuste fundamental que depende por completo del cuello de botella principal, que a menudo varía con la longitud de la secuencia.
Para cargas de trabajo de secuencia larga (como 128k): En estos casos, el tamaño de los tensores de activación es el consumidor dominante de HBM. Por lo general, la carga de trabajo está vinculada a la memoria. Por lo tanto, aplicar una política de rematerialización agresiva es muy beneficioso. El ahorro de memoria permite que el entrenamiento continúe sin errores de memoria insuficiente y también permite tamaños de lote más grandes, y la sobrecarga computacional de volver a calcular es una compensación que vale la pena.
Para cargas de trabajo de secuencia corta (como 8k): En estos casos, la memoria de activación es mucho menos preocupante, y es más probable que la carga de trabajo esté vinculada a la capacidad de procesamiento. La sobrecarga computacional de la rematerialización puede ser la mayor fuente de ineficiencia.
Ajuste de las políticas de rematerialización en MaxText
MaxText proporciona un control detallado sobre la rematerialización a través de un conjunto de políticas preestablecidas y personalizadas, configuradas con la marca remat_policy.
Políticas preestablecidas
MaxText ofrece las siguientes políticas integradas:
full: La política más agresiva, que rematerializa casi todo. Esto minimiza el uso de HBM, pero maximiza la sobrecarga de volver a calcular. Es ideal para situaciones de secuencia larga con restricciones de memoria extremas.minimal: La política menos agresiva, que almacena la mayoría de las activaciones. Esto maximiza el uso de HBM, pero minimiza la sobrecarga de volver a calcular. Es ideal para cargas de trabajo de secuencia corta vinculadas a la capacidad de procesamiento en las que la memoria no es una preocupación.- Políticas intermedias: Las opciones como
save_dot_with_context_except_mlp,save_qkv_projysave_out_projproporcionan varias compensaciones mediante la creación selectiva de puntos de control de los resultados de las operaciones costosas de producto punto, mientras que se rematerializan las operaciones elementales más económicas.
Políticas personalizadas
Para obtener un mayor nivel de control, puedes establecer remat_policy en custom. Esto te permite especificar el comportamiento de las capas individuales dentro del módulo de decodificación del modelo. A cada capa se le puede asignar uno de los siguientes tres comportamientos:
device: La activación se almacena en HBM en el dispositivo de TPU.remat: La activación se descarta y se rematerializará durante la retropropagación.offload: La activación se mueve de HBM a la memoria del host de la CPU, lo que libera HBM a costa de la latencia de transferencia de PCIe.
Ajuste de VMEM con alcance
El rendimiento del kernel, como la atención flash, depende de los tamaños de mosaico seleccionados en el kernel, cuyo tamaño está limitado por la memoria vectorial disponible (VMEM). Cada uno de los dos TensorCores de un chip TPU7x tiene 64 MiB de memoria vectorial (VMEM). Esta capacidad de VMEM se puede dividir entre el alcance actual (VMEM con alcance) y la recuperación previa de peso futuro. El aumento de VMEM con alcance permite aumentar los tamaños de mosaico en el kernel, lo que puede reducir las detenciones de memoria y aumentar el rendimiento de los kernels. Puedes modificar el tamaño de VMEM con alcance si estableces xla_tpu_scoped_vmem_limit_kib (en LIBTPU_INIT_ARGS), que se puede usar para explorar el rendimiento del kernel, así como los límites de rendimiento de extremo a extremo. La optimización del tamaño de VMEM con alcance puede afectar indirectamente el rendimiento del kernel de Pallas personalizado, ya que el aumento de VMEM con alcance desbloquea un espacio de búsqueda de hiperparámetros más grande para los tamaños de mosaico en el kernel.
Kernels de Tokamax
Tokamax, una biblioteca de kernels de JAX de alto rendimiento con muchos kernels de TPU altamente optimizados, aborda varios cuellos de botella comunes específicos del hardware:
- Atención flash: La atención flash se usa como la implementación de atención principal para eliminar el cuello de botella de HBM de la atención estándar y usa la implementación de atención más eficiente en las TPU.
- Multiplicación de matrices agrupadas de Megablox (GMM): Para las cargas de trabajo de MoE, Megablox controla de manera eficiente las multiplicaciones de matrices agrupadas mediante el cálculo sobre la representación de activaciones irregulares. Se asigna de manera eficiente a la dimensión irregular, y calcula las multiplicaciones de matrices entre grupos irregulares de filas en LHS y la matriz experta correspondiente, lo que evita la necesidad de rellenar lotes a un tamaño fijo.
- Ajuste empírico con
tune-jax: La bibliotecatune-jaxtiene utilidades para realizar búsquedas empíricas de tamaños de bloque óptimos. Los tamaños de kernel predeterminados suelen ser subóptimos. El ajuste permite elegir tamaños de mosaico de VMEM compatibles con el hardware para maximizar el uso del hardware. - Estimación de logits máximos: El kernel de atención flash de Tokamax se puede optimizar aún más si se
establece un valor para
max_logit_const. Si se establece, reemplaza el cálculo de reducción del logit máximo durante la operación softmax de atención (softmax(Q * KT)), lo que reduce parte de la sobrecarga computacional y de sincronización. En MaxText, se implementa mediante la configuraciónuse_max_logits_estimate, que se puede establecer enNone(inhabilitado) o en un valor de punto flotante. Verifica que el rango de logits de tu modelo específico siga siendo compatible con la estimación para evitar el desbordamiento numérico. Se recomienda realizar pruebas de convergencia si se establece este valor.