AdamW, SGD和L2正则化以及权重衰减

Python3.8

Python 是一种高级、解释型、通用的编程语言,以其简洁易读的语法而闻名,适用于广泛的应用,包括Web开发、数据分析、人工智能和自动化脚本

AdamW, SGD和L2正则化以及权重衰减

虽然已经多次使用过AdamW和SGD,但是对它们的原理以及各种超参数并不是很了解,对于两种优化器和L2正则化以及权重衰减的关系也云里雾里,下定决心要写一篇文章梳理清楚。


一、L2正则化和梯度衰减

1、L2正则化

L2正则化通过在损失函数中增加一个权重平方和的偏置项,促使模型倾向于产生较小的权重,从而防止模型过拟合并提升模型的泛化能力。
假设当前时刻 t t t的权重为 x t x_t xt,损失函数为 f ( x t ) f(x_t) f(xt),那么加上L2正则化以后的损失函数为
f r ( x t ) = f ( x t ) + w t 2 ∣ ∣ x t ∣ ∣ 2 2 (1) f_r(x_t) = f(x_t) + \frac{w_t}{2} ||x_t||_2^2 \tag1 fr(xt)=f(xt)+2wt∣∣xt22(1)

为什么L2正则化可以使模型产生较小的权重: 因为给损失函数增加的惩罚和权重的平方有关,权重较小则产生的惩罚也较小,而较小的权重意味着模型在特征空间中更“平滑”,不会对训练数据中的小波动或异常值过于敏感。这种平滑性使得模型在面对新数据时表现更稳定,降低了过拟合的风险。

2、梯度衰减

权重衰减在更新梯度(注意与L2正则的区别)时减去一个关于权重的偏置,的简单公式表达如下
x t + 1 = ( 1 − w ) x t − α t ∇ f ( x t ) (2) x_{t+1} = (1 - w)x_t - \alpha_t \nabla f(x_t) \tag2 xt+1=(1w)xtαtf(xt)(2)
其中 w w w是权重衰减率, α t \alpha_t αt是学习率, ∇ f ( x t ) \nabla f(x_t) f(xt)是梯度

3、L2正则化和梯度衰减

在某些情况下,L2正则化和梯度衰减可以是等价的。对(1)求导,
∇ f r ( x t ) = ∇ f ( x t ) + w t x t (3) \nabla f_r(x_t) = \nabla f(x_t) + w_tx_t \tag3 fr(xt)=f(xt)+wtxt(3)
则当前时刻的权重为
x t + 1 = x t − α t ∇ f r ( x t ) = x t − α t ( ∇ f ( x t ) + w t x t ) x_{t+1} = x_t - \alpha_t \nabla f_r(x_t) = x_t - \alpha_t(\nabla f(x_t) + w_tx_t) xt+1=xtαtfr(xt)=xtαt(f(xt)+wtxt)
整理公式,
x t + 1 = ( 1 − α t w t ) x t − α t ∇ f ( x t ) x_{t+1} = (1-\alpha_tw_t)x_t- \alpha_t \nabla f(x_t) xt+1=(1αtwt)xtαtf(xt)
此时只要令 w t = w α t w_t = \frac{w}{\alpha_t} wt=αtw,则可得到与(3)相等的公式。

二、梯度更新中的梯度衰减

SGD

SGD(stochastic gradient descent)是一种常用的优化方法,在实际应用中主要采用带有动量( m t m_{t} mt)的SGD,下图中紫色部分是用L2正则化实现,绿色部分是直接用梯度衰减实现
SGD
如果在第6行加上偏置项(相当于L2正则化),那么计算动量时(第8行)也会受影响,不过SGD可以通过对w重参数化达到与直接权重衰减同样的效果,使用L2正则化代替权重衰减,会导致正则和学习率耦合。

AdamW

Adam如果采用L2正则化的方式,那么会得到如下公式
在这里插入图片描述
此时 w t x t − 1 w_tx_{t-1} wtxt1也会受到分母(即 v t v_t vt)的影响,当梯度较大时分母也会较大(详细解释见https://towardsdatascience.com/why-adamw-matters-736223f31b5d),而梯度较小时分母也较小,这样导致梯度较大时权重被正则地比梯度较小时要少( w t x t − 1 w_tx_{t-1} wtxt1变小)。并且Adam无法跟SGD一样使用重参数化使L2正则达到和梯度衰减一样的效果,所以AdamW的作者提出在真正更新梯度时再加上权重衰减项(图中绿色部分),将权重衰减从梯度更新中解耦出来。


参考文献

  1. https://towardsdatascience.com/why-adamw-matters-736223f31b5d
  2. FIXING WEIGHT DECAY REGULARIZATION IN ADAM

您可能感兴趣的与本文相关的镜像

Python3.8

Python3.8

Conda
Python

Python 是一种高级、解释型、通用的编程语言,以其简洁易读的语法而闻名,适用于广泛的应用,包括Web开发、数据分析、人工智能和自动化脚本

评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值