PyTorch学习笔记:RuntimeError: one of the variables needed for gradient computation has been modified by

 报错信息:

RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor [784, 512]], which is output 0 of TBackward, is at version 3; expected version 2 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).

Traceback (most recent call last):
  File "E:\Anaconda\envs\torch_c_13\lib\site-packages\IPython\core\interactiveshell.py", line 3343, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-2-18bbaf295f9c>", line 1, in <module>
    runfile('E:/Code/AEs by PyTorch/AEsingle_train_test_temp.py', wdir='E:/Code/AEs by PyTorch')
  File "E:\SoftWare\PyCharm\PyCharm 2021.2.3\plugins\python\helpers\pydev\_pydev_bundle\pydev_umd.py", line 198, in runfile
    pydev_imports.execfile(filename, global_vars, local_vars)  # execute the script
  File "E:\SoftWare\PyCharm\PyCharm 2021.2.3\plugins\python\helpers\pydev\_pydev_imps\_pydev_execfile.py", line 18, in execfile
    exec(compile(contents+"\n", file, 'exec'), glob, loc)
  File "E:/Code/AEs by PyTorch/AEsingle_train_test_temp.py", line 205, in <module>
    train_ae_x_h2 = AEtrain(AEmodel2, train_ae_x_h1, 10, "AEmodel2")
  File "E:/Code/AEs by PyTorch/AEsingle_train_test_temp.py", line 95, in AEtrain
    loss.backward()
  File "E:\Anaconda\envs\torch_c_13\lib\site-packages\torch\tensor.py", line 195, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File "E:\Anaconda\envs\torch_c_13\lib\site-packages\torch\autograd\__init__.py", line 99, in backward
    allow_unreachable=True)  # allow_unreachable flag
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor [784, 512]], which is output 0 of TBackward, is at version 3; expected version 2 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).

解决方法:

查看报错信息发现是loss.backward()部分报错,查阅多篇博客后找到解决方法如下

optimizer.zero_grad()
loss.backward(retain_graph=True)
optimizer.step()

将loss.backward()函数内的参数retain_graph值设置为True即可成功解决。

retain_graph=True,这个参数的作用是什么,官方定义为:

retain_graph (bool, optional) – If False, the graph used to compute the grad will be freed. Note that in nearly all cases setting this option to True is not needed and often can be worked around in a much more efficient way. Defaults to the value of create_graph.

大意是如果设置为False,计算图中的中间变量在计算完后就会被释放。但是在平时的使用中这个参数默认都为False从而提高效率,和creat_graph的值一样。

参考链接:

另,该问题可能在不同版本的pytorch中出现,可以参考以下链接的回答:

求pytorch大神解答,问题出在哪里 - 虎扑社区 (hupu.com)

也可能会因为代码中出现a=a.xxx这种赋值而报错,可改成b=a.xxx

【PyTorch】RuntimeError: one of the variables needed for gradient computation has been modified by an_ncc1995的博客-优快云博客

 PyTorch中retain_graph参数的作用

Pytorch中retain_graph参数的作用 - Oldpan的个人博客

### 运行时错误 `RuntimeError` 的原因与解决方案 当遇到 `RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation` 错误时,这通常是因为在 PyTorch 中某些张量被进行了就地(inplace)修改操作[^1]。这种行为会破坏反向传播所需的计算图结构,从而导致梯度无法正确计算。 #### 原因分析 该错误的核心在于某个张量在其生命周期中发生了版本号的变化。PyTorch 使用自动求导机制来跟踪张量的操作历史记录,并通过这些记录构建动态计算图以便于后续的梯度计算。如果在此过程中对张量执行了就地修改操作,则可能导致其状态发生变化而失去原始的历史记录[^2]。具体表现为: - 张量的状态与其预期版本不匹配。 - 反向传播尝试访问已被更改的数据或元数据。 #### 解决方案 以下是几种常见的解决方法: 1. **禁用就地操作** 避免使用任何带有 `_` 后缀的方法(如 `.add_()` 或 `.relu_()`),因为它们会对输入张量进行就地修改。改用返回新对象而非改变原有对象的方式实现相同功能。例如,将以下代码片段中的就地赋值替换为克隆副本后再处理: ```python # 替代前 x[:, :] += y # 替代后 x = x + y ``` 2. **启用异常检测模式** 利用 PyTorch 提供的功能开启异常捕获选项可以帮助定位问题所在位置。设置如下参数即可激活此特性: ```python import torch torch.autograd.set_detect_anomaly(True) ``` 当再次触发相同的错误时,系统将会提供更详细的堆栈信息指出具体的失败点位[^3]。 3. **检查模型定义部分是否存在不当操作** 如果确认问题是来源于自定义神经网络层内部逻辑的话,则需仔细审查每一处涉及可训练权重更新的地方是否有潜在风险引入不必要的副作用。比如下面这个例子展示了如何安全地完成阈值裁剪而不影响其他组件的工作流程: ```python max_list_min = ... input_masked = (input > max_list_min).float() x_cloned = x.clone() # 创建独立拷贝避免污染源数据 x_cloned *= input_masked.unsqueeze(-1) # 应用于复制版上实施变换 ``` 4. **调试工具辅助排查** 结合打印语句或者可视化库进一步探索中间结果变化规律也是很有帮助的办法之一。可以周期性输出目标节点的相关属性验证一致性假设成立与否。 ```python print(f'Gradient Function:{ratio.grad_fn}, Version Number:{ratio._version}') ``` 以上措施综合运用能够有效缓解乃至彻底消除此类棘手状况的发生概率。 ---
评论 9
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值