Pytorch AMP——自动混合精度训练

文章介绍了PyTorch中的浮点数类型,包括FP32、FP16和BFloat16,强调了FP16在节省内存和加速计算方面的作用。自动混合精度(AMP)模块在支持的GPU上能有效减少显存消耗和加快训练速度。文章详细阐述了AMP的工作机制,包括自动类型转换、GradientScaling以及如何防止underflow和overflow。此外,还提到了在训练过程中如何使用GradScaler进行梯度缩放和更新,并给出了避免训练中出现NAN的建议。

浮点数类型

Pytorch 有好几种类型的浮点数,它们占的内存大小不同,自然也有不同的精度:

torch.FloatTensor (32-bit floating point)
torch.DoubleTensor (64-bit floating point)
torch.HalfTensor (16-bit floating point 1)
torch.BFloat16Tensor (16-bit floating point 2)

Data type dtype
32-bit floating point torch.float32 or torch.float
64-bit floating point torch.float64 or torch.double
16-bit floating point [1] torch.float16 or torch.half
16-bit floating point [2] torch.bfloat16

[1] Referred to as binary16: uses 1 sign, 5 exponent, and 10 significand bits. Useful when precision is important.
[2] Referred to as Brain Floating Point: use 1 sign, 8 exponent and 7 significand bits. Useful when range is important, since it has the same number of exponent bits as float32

半精度浮点数 (FP16) 是一种计算机使用的二进制浮点数数据类型,使用 2 字节 (16 位) 存储,表示范围为 [ − 6.5 e 4 , − 5.9 e − 8 ] ∪ [ 5.9 e − 8 , 6.5 e 4 ] [-6.5e^4, -5.9e^{-8}] \cup [5.9e^{-8}, 6.5e^4] [6.5e4,5.9e8][5.9e8,6.5e4]PyTorch 默认使用单精度浮点数 (FP32) 进行网络模型的计算和权重存储。FP32 在内存中用 4 字节 (32 位) 存储,表示范围为 [ − 3 e 38 , − 1 e − 38 ] ∪ [ 1 e − 38 , 3 e 38 ] [-3e^{38}, -1e^{-38}]

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值