动态计算图是在程序前向传播的过程中构建起来的,主要是用来进行反向传播。相比搭建网络结构时关注每一层的计算方式,计算图主要视角是数据节点(Tensor)。
在计算图构建和反向传播过程中存在一些令人混淆的概念,例如is_leaf、requires_grad、detach()、zero_grad()、 retain_grad()、torch.nograd()。从计算图反向传播的角度去理解这些概念,一切就变的清晰了。
动态图中的反向传播

图1 动态计算图
上图是计算图的示意图:X1和X2是两组输入数据Tensor,P1和P2是网络的权重Tensor,Y和Z是计算的中间结果Tensor,Fn是得到中间结果的计算操作。
-
训练网络的最终目的是更新P1和P2的值,因此需要计算loss关于P1和P2的梯度,为了得到关于P1和P2的梯度,需要依次计算loss关于中间结果Y和Z的梯度。
-
权重Tensor的梯度计算不依赖输入数据且输入数据X不需要更新数值,所以X1和X2不需要计算梯度。
因此默认情况下用户创建的输入数据requirs_grad=False, 网络的权重参数requirs_grad=True
1. 叶子节点
-
由用户的上帝之手直接创造出的Tensor为叶子节点, 这些节点没有记录grad_fn参数(例如输入数据网络权重)。
-
由需要梯度计算的叶子节点通过运算衍生出来的Tensor为非叶子节点,这些节点有grad_fn参数。

本文深入解析PyTorch动态计算图的工作原理,包括叶子节点、requires_grad属性、detach()、zero_grad()、retain_grad()和torch.no_grad()的作用。在训练神经网络时,了解这些概念对于有效地进行反向传播和权重更新至关重要。通过实例展示了如何利用这些工具在反向传播过程中控制梯度计算和权重更新。

1569

被折叠的 条评论
为什么被折叠?



