深入解析PyTorch中的kaiming_uniform初始化:原理与实践

1. 从“为什么”开始:聊聊神经网络初始化的那些事儿

如果你刚开始玩深度学习,搭建模型时可能没太在意权重初始化这步,觉得交给框架默认处理就行了。我自己刚开始也是这么想的,直到有一次训练一个稍微深点的网络,死活训不起来,损失要么不降,要么直接爆炸成NaN。折腾了好久,最后发现问题就出在初始化上。从那以后,我才真正明白,好的初始化不是锦上添花,而是模型能否顺利训练的“起跑线”。

你可以把神经网络想象成一个巨大的、多层的信号加工厂。每一层都有很多“工人”(神经元),他们手里拿着“工具”(权重参数),负责对输入信号进行加工,然后传递给下一层。在训练开始前,我们必须给这些“工人”分发初始的“工具”。如果分发的工具(权重)太大(绝对值很大),信号在层间传递时就会被过度放大,经过几层累积就可能“爆炸”,变成巨大的数值,导致计算溢出。反过来,如果工具太小,信号传递几层后就“消失”得无影无踪,后面层的神经元根本接收不到有效的信号,梯度也就传不回去了。这两种情况都会导致训练失败。

所以,初始化的核心目标就一个:让信号(激活值)在网络的前向传播过程中,以及误差信号(梯度)在反向传播过程中,都能保持一个稳定的尺度,既不会指数级增长,也不会指数级衰减。 这个思想在学术上被称为“保持方差稳定”。PyTorch里的kaiming_uniform初始化,就是大名鼎鼎的何恺明大神为了解决这个问题而提出的,特别适合我们后面要用的ReLU及其变种这类激活函数。它不是什么黑魔法,而是一个经过严密数学推导的、非常实用的工程方案。

2. 庖丁解牛:kaiming_uniform的数学原理与推导

知道了“为什么”,我们再来啃“是什么”。kaiming_uniform的全称是“Kaiming均匀分布初始化”。它的核心思想来源于2015年何恺明等人的论文《Delving Deep into Rectifiers: Surpassing Human-Level Performance on ImageNet Classification》。这篇论文指出,对于使用ReLU激活函数的网络,沿用之前针对Sigmoid/Tanh的Xavier初始化方法并不最优,需要一套新的方差标准。

2.1 核心假设与推导起点

推导的起点基于几个合理的假设:

  1. 权重w是独立同分布的随机变量,均值为0。
  2. 输入数据x也是独立同分布的随机变量,均值为0,且与权重w相互独立。
  3. 我们暂时不考虑偏置项b

假设某一层的输入有n_in个神经元(也称为fan_in),输出有n_out个神经元(fan_out)。该层的线性变换为:y = w1*x1 + w2*x2 + ... + w_{n_in} * x_{n_in}

我们的目标是:希望该层输出的方差Var(y),尽可能等于该层输入的方差Var(x)。这样信号强度就能在层间保持稳定。

根据方差的性质,如果wx独立且均值为0,那么Var(y) = n_in * Var(w) * Var(x)。 为了让Var(y) = Var(x),我们就需要:n_in * Var(w) = 1。 因此,权重的方差应该满足:Var(w) = 1 / n_in

这就是mode='fan_in'模式下的核心公式。它保证了在前向传播时,输出的方差不会因为神经元数量n_in而改变。

2.2 为ReLU家族量身定制的增益(Gain)

上面的推导是针对线性激活的。但如果我们在后面接了一个ReLU激活函数a = max(0, y),事情就变了。因为ReLU把一半的负值信号都掐掉了(置为0),这会导致输出的方差减小大约一半。

经过计算,如果y是均值为0、对称分布的数据,经过ReLU后,其方差Var(a)大约是Var(y)的一半。为了补偿这个损失,我们需要在初始化时把权重的方差相应地放大一倍,来保证Var(a)仍然等于Var(x)

所以,新的权重方差要求变成了:Var(w) = 2 / n_in

这个“2”,就是针对标准ReLU的增益(Gain)。它是对激活函数非线性特性的一种补偿系数。

那么对于Leaky ReLU呢?它的公式是f(x) = x if x>=0 else negative_slope * x,其中negative_slope就是我们常说的参数a(一个小常数,比如0.01)。经过更一般的推导,Leaky ReLU对应的增益公式为:

gain = sqrt(2 / (1 + a

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值