
在pytorch中,当我们构建一个神经网络并执行前向传播后,可以通过loss.backward()触发反向传播,计算模型参数的梯度。这些梯度是优化器更新参数的基础。然而,有时为了调试或深入理解模型的内部工作机制,我们可能需要查看非叶子节点(即计算图中的中间张量)的梯度。
PyTorch的自动微分系统(Autograd)默认情况下,在反向传播完成后会释放中间张量的梯度,以节省内存。此外,torch.nn.Module提供的register_full_backward_hook等钩子函数主要设计用于捕获模块输入和输出的梯度,或与模块参数相关的梯度,而非直接用于任意中间张量的梯度。
错误的尝试:使用钩子获取中间张量梯度许多开发者可能会尝试使用模块的后向钩子来捕获中间张量的梯度,例如以下代码所示:
import torch
import torch.nn as nn
class func_NN(nn.Module):
def __init__(self):
super().__init__()
self.a = nn.Parameter(torch.rand(1))
self.b = nn.Parameter(torch.rand(1))
def forward(self, inp):
mul_x = torch.cos(self.a.view(-1, 1) * inp)
sum_x = mul_x - self.b
return sum_x
# 钩子函数
def backward_hook(module, grad_input, grad_output):
print("module: ", module)
print("inp_grad: ", grad_input)
print("out_grad: ", grad_output)
# 模拟训练过程
a_true = torch.Tensor([0.5])
b_true = torch.Tensor([0.8])
x = torch.linspace(-1, 1, 10)
y = a_true * x + (0.1**0.5) * torch.randn_like(x) * (0.001) + b_true
inp = torch.linspace(-1, 1, 10)
foo = func_NN()
# 注册一个全反向传播钩子
handle_ = foo.register_full_backward_hook(backward_hook)
loss_fn = nn.MSELoss()
optimizer = torch.optim.Adam(foo.parameters(), lr=0.001)
print("--- 尝试使用钩子 ---")
for i in range(1): # 只运行一次以观察输出
optimizer.zero_grad()
output = foo.forward(inp=inp)
loss = loss_fn(y, output)
loss.backward()
optimizer.step()
handle_.remove() # 移除钩子 上述代码中的backward_hook会打印func_NN模块的输入梯度和输出梯度,但它并不能直接提供mul_x或sum_x这些模块内部计算产生的中间张量的梯度。这是因为register_full_backward_hook捕获的是模块作为整体的输入和输出梯度,而不是其内部任意子表达式的梯度。
正确的方法:使用 retain_grad() 捕获中间张量梯度要获取中间张量的梯度,我们需要明确告诉PyTorch的Autograd系统不要在反向传播后释放这些张量的梯度。这可以通过调用张量的retain_grad()方法来实现。此外,由于局部变量在函数结束后会超出作用域,我们需要将这些中间张量的引用存储在某个地方(例如作为nn.Module的属性),以便在反向传播完成后访问它们的.grad属性。
Post AI
博客文章AI生成器
50
查看详情
以下是修改后的代码示例:
import torch
import torch.nn as nn
class func_NN_RetainGrad(nn.Module):
def __init__(self):
super().__init__()
self.a = nn.Parameter(torch.rand(1))
self.b = nn.Parameter(torch.rand(1))
# 用于存储中间张量的引用
self.mul_x = None
self.sum_x = None
def forward(self, inp):
mul_x = torch.cos(self.a.view(-1, 1) * inp)
sum_x = mul_x - self.b
# 关键步骤1: 对需要保留梯度的中间张量调用 retain_grad()
mul_x.retain_grad()
sum_x.retain_grad()
# 关键步骤2: 存储中间张量的引用,以便反向传播后访问其 .grad 属性
self.mul_x = mul_x
self.sum_x = sum_x
return sum_x
# 模拟数据
a_true = torch.Tensor([0.5])
b_true = torch.Tensor([0.8])
x = torch.linspace(-1, 1, 10)
y = a_true * x + (0.1**0.5) * torch.randn_like(x) * (0.001) + b_true
inp = torch.linspace(-1, 1, 10)
foo_retain = func_NN_RetainGrad()
loss_fn = nn.MSELoss()
optimizer = torch.optim.Adam(foo_retain.parameters(), lr=0.001)
print("\n--- 使用 retain_grad() 获取中间张量梯度 ---")
# 执行一次前向传播和反向传播
output = foo_retain.forward(inp=inp)
loss = loss_fn(y, output)
loss.backward() # 执行反向传播
# 反向传播完成后,现在可以访问中间张量的 .grad 属性
print("mul_x 的梯度:\n", foo_retain.mul_x.grad)
print("sum_x 的梯度:\n", foo_retain.sum_x.grad)
# 验证参数梯度是否正常计算
print("参数 a 的梯度:\n", foo_retain.a.grad)
print("参数 b 的梯度:\n", foo_retain.b.grad) 在这个修正后的示例中:
- 我们在forward方法中计算mul_x和sum_x之后,立即调用了它们的retain_grad()方法。这告诉Autograd在反向传播过程中不要清除这些张量的梯度信息。
- 我们将mul_x和sum_x赋值给self.mul_x和self.sum_x,将它们的引用存储在模块实例中。这样,即使forward方法执行完毕,我们仍然可以通过foo_retain.mul_x和foo_retain.sum_x访问到这些张量。
- 在调用loss.backward()之后,这些被保留的中间张量的梯度就可以通过它们的.grad属性被访问到并打印出来。
- 内存消耗: retain_grad()会阻止Autograd释放中间张量的梯度,这会增加内存消耗。因此,应仅在调试或特定需求时使用,并在不再需要时移除或避免在生产代码中大量使用。
- 适用场景: retain_grad()适用于获取计算图中的任意中间张量的梯度。而模块钩子(如register_full_backward_hook)更适用于监控模块的输入/输出梯度,或者在模块级别执行一些操作。参数钩子(如param.register_hook)则用于直接修改或观察参数的梯度。
- 调试工具: retain_grad()是一个强大的调试工具,可以帮助我们理解梯度流,发现潜在的梯度消失或梯度爆炸问题,或者验证自定义反向传播的正确性。
- 何时调用: 必须在执行loss.backward()之前调用retain_grad()。如果在一个张量上多次调用retain_grad(),不会有额外影响。
- 叶子节点: 对于叶子节点(如nn.Parameter),其梯度默认会被保留(如果requires_grad=True),无需调用retain_grad()。
在PyTorch中获取非叶子节点(中间张量)的梯度,不能直接依赖于nn.Module的后向钩子。正确的做法是利用张量的retain_grad()方法,并在前向传播时将这些中间张量存储为模块的属性。这样,在反向传播完成后,我们就可以通过访问这些属性的.grad字段来获取其梯度。理解并正确使用retain_grad()对于深入调试和优化PyTorch模型至关重要,但同时也要注意其可能带来的内存开销。
以上就是PyTorch中获取中间张量梯度值的实用指南的详细内容,更多请关注知识资源分享宝库其它相关文章!
相关标签: 工具 ai 神经网络 pytorch 作用域 cos 局部变量 作用域 pytorch 大家都在看: Python怎么升级pip和第三方库_pip包管理工具升级指南 python virtualenv和venv有什么区别_python虚拟环境工具virtualenv与venv的对比 Python怎么获取当前工作目录_Python获取当前路径操作指南 向 Plotly Dash 应用图表工具栏添加全屏图标 TensorFlow中高效实现多项式回归:从深度网络到特征工程






发表评论:
◎欢迎参与讨论,请在这里发表您的看法、交流您的观点。