1. 从“要不要算”到“怎么算”:理解PyTorch梯度追踪的基石
如果你刚开始接触PyTorch,可能会对requires_grad和requires_grad_()这两个看起来几乎一样的名字感到困惑。它们都跟梯度计算有关,但一个是属性,一个是方法,用错了地方,你的模型可能就“学”不动了。简单来说,你可以把requires_grad想象成一张张量的“身份证”,上面写着“我需要计算梯度”或者“我不需要”。而requires_grad_()则是你用来修改这张身份证信息的“笔”,而且是直接在上面涂改的那种。
我刚开始用PyTorch那会儿,就踩过这个坑。当时我想冻结预训练模型的前几层,只训练后面的分类头。我傻乎乎地写了个循环,把每一层参数的requires_grad属性设成False,结果训练时发现内存占用一点没少,反向传播照样慢。后来才发现,我虽然改了属性,但PyTorch的计算图里,那些张量依然被标记为需要梯度。直到我用了requires_grad_()方法,才真正实现了参数的冻结。这个经历让我明白,理解这两者的区别,绝不是纸上谈兵,它直接关系到你模型训练的效率和内存开销。
那么,为什么PyTorch要设计这两套机制呢?这背后其实是声明式与命令式操作的区别。requires_grad属性让你在创建张量时,就声明好它的“命运”——是否参与梯度计算。这就像你出生时就被赋予了一个身份。而requires_grad_()方法则允许你在张量“出生”后,随时改变它的这个身份,是一种更灵活、更动态的操作。理解了这个核心,你就能在模型微调、参数冻结、多任务学习等复杂场景下游刃有余,而不是仅仅停留在“知道有这么个东西”的层面。
2. requires_grad:静态的属性与计算图的构建逻辑
2.1 属性的本质与默认行为
requires_grad是torch.Tensor类的一个布尔属性。它的默认值是False。这意味着,当你创建一个普通的张量时,PyTorch的自动微分引擎(autograd)默认是不会追踪对这个张量的任何操作的,自然也就不会为它计算梯度。这很合理,因为大部分时候,我们的输入数据、中间缓存都不需要梯度,只有模型的参数需要。
import torch
# 创建一个普通张量,默认 requires_grad=False
x = torch.tensor([1.0, 2.0, 3.0])
print(x.requires_grad) # 输出: False
# 创建一个需要梯度追踪的张量
w = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
print(w.requires_grad) # 输出: True
这里有个关键点:requires_grad=True这个属性是可传染的。如果一个操作的输入张量中,只要有一个requires_grad=True,那么该操作的输出张量默认也会requires_grad=True。这是PyTorch动态计算图构建的核心规则之一。比如,你用需要梯度的权重w乘以不需要梯度的数据x,得到的输出y会自动要求梯度,因为它的计算路径依赖于w。
y = w * x # w.requires_grad=True, x.requires_grad=False
print(y.requires_grad) # 输出: True
print(y.grad_fn) # 输出: <MulBackward0 object at 0x...>, 说明它有梯度函数
这个“传染”机制确保了计算图的连贯性。autograd会沿着从输出(通常是损失函数)到输入(模型参数)的路径,反向追踪所有requires_grad=True的张量,并计算它们的梯度。那些requires_grad=False的张量则被排除在这个反向传播路径之外,既节省了计算量,也避免了无谓的梯度计算。
2.2 通过属性设置梯度需求
既然requires_grad是个属性,那我们能不能直接修改它呢?答案是:可以,但不推荐直接赋值。你可能会尝试x.requires_grad = True。在某些简单情况下,这似乎能工作:
x = torch.tensor([1.0, 2.0])
x.requires_grad = True
print(x.requires_grad) # 输出: True
但是,这里隐藏着一个大坑!直接修改属性是一种非原位(non-inplace)且可能不安全的操作。对于从已有张量通过某些操作(如切片、索引)得到的视图(view),或者已经参与过计算图构建的张量,直接修改requires_grad属性可能会导致未定义行为或错误。PyTorch官方文档也建议使用requires_grad_()方法来修改这个属性,因为它经过了更严格的安全性和一致性检查。所以,记住一个原则:创建时声明用参数,创建后


1万+

被折叠的 条评论
为什么被折叠?



