浮点数类型
Pytorch 有好几种类型的浮点数,它们占的内存大小不同,自然也有不同的精度:
torch.FloatTensor (32-bit floating point)
torch.DoubleTensor (64-bit floating point)
torch.HalfTensor (16-bit floating point 1)
torch.BFloat16Tensor (16-bit floating point 2)
| Data type | dtype |
|---|---|
| 32-bit floating point | torch.float32 or torch.float |
| 64-bit floating point | torch.float64 or torch.double |
| 16-bit floating point [1] | torch.float16 or torch.half |
| 16-bit floating point [2] | torch.bfloat16 |
[1] Referred to as binary16: uses 1 sign, 5 exponent, and 10 significand bits. Useful when precision is important.
[2] Referred to as Brain Floating Point: use 1 sign, 8 exponent and 7 significand bits. Useful when range is important, since it has the same number of exponent bits as float32
半精度浮点数 (FP16) 是一种计算机使用的二进制浮点数数据类型,使用 2 字节 (16 位) 存储,表示范围为 [ − 6.5 e 4 , − 5.9 e − 8 ] ∪ [ 5.9 e − 8 , 6.5 e 4 ] [-6.5e^4, -5.9e^{-8}] \cup [5.9e^{-8}, 6.5e^4] [−6.5e4,−5.9e−8]∪[5.9e−8,6.5e4] 。PyTorch 默认使用单精度浮点数 (FP32) 进行网络模型的计算和权重存储。FP32 在内存中用 4 字节 (32 位) 存储,表示范围为 [ − 3 e 38 , − 1 e − 38 ] ∪ [ 1 e − 38 , 3 e 38 ] [-3e^{38}, -1e^{-38}]

文章介绍了PyTorch中的浮点数类型,包括FP32、FP16和BFloat16,强调了FP16在节省内存和加速计算方面的作用。自动混合精度(AMP)模块在支持的GPU上能有效减少显存消耗和加快训练速度。文章详细阐述了AMP的工作机制,包括自动类型转换、GradientScaling以及如何防止underflow和overflow。此外,还提到了在训练过程中如何使用GradScaler进行梯度缩放和更新,并给出了避免训练中出现NAN的建议。

2万+

被折叠的 条评论
为什么被折叠?



