1. 为什么需要ZeRO Stage 3?
当你尝试在4张A100显卡上训练一个700亿参数的Llama模型时,大概率会遇到显存爆炸(OOM)的问题。传统数据并行(DP)训练中,每张显卡都要完整存储模型参数、梯度和优化器状态。以16位混合精度计算,7B参数模型仅参数就占用14GB显存,加上梯度和Adam优化器的动量/方差状态,总显存需求轻松突破100GB/GPU——这还没算上激活值和临时缓冲区的开销。
ZeRO(Zero Redundancy Optimizer)的核心理念是"分而治之":
- Stage 1:仅分区优化器状态(省2倍显存)
- Stage 2:增加梯度分区(再省2倍)
- Stage 3:终极形态,连模型参数也分区(再省4倍)
实际效果有多夸张?用4张显卡训练7B模型时:
- 传统DP:112GB/GPU → 只能跑7B模型
- ZeRO-3 + offload:20GB/GPU → 能跑40B+模型
2. 参数分区的实现原理
2.1 哈希分区算法
Stage 3的核心是把参数均匀分散到所有GPU。假设有N张显卡:
for param_idx, param in enumerate(model.parameters()):
pid = param_idx % N # 哈希计算分区ID
if pid == local_rank: # 当前GPU负责该分区
param.data = init_value()
else:
param.data = None # 其他分区不占用显存
这就好比4个人分蛋糕:
- GPU0负责第1/4块参数
- GPU1负责第2/4块
- ...


1万+

被折叠的 条评论
为什么被折叠?



