【Pytorch】backward的 retain_variables=True与retain_graph=True 区别及应用

本文详细解析了PyTorch中backward()函数的retain_graph参数作用,解释了其与已废弃的retain_variables参数的区别,及在多损失函数反向传播场景下如何正确使用以避免计算图被释放,确保多次反向传播的顺利进行。
部署运行你感兴趣的模型镜像

首先 :retain_variables=True 与 retain_graph=True   没有任何区别 

        只是 retain_variables 会在pytorch 新版本中被取消掉,将使用 retain_graph 。

        所以 在使用 retain_variables=True 时报错:

        TypeError: backward() got an unexpected keyword argument 'retain_variables'

 

应用:

两者都是backward()的参数   默认 retain_graph=False,

也就是反向传播之后这个计算图的内存会被释放,这样就没办法进行第二次反向传播了,

所以我们需要设置为 retain_graph=True,

loss_1.backward(retain_graph=True) 
 

retain_graph:

每次 backward() 时,默认会把整个计算图free掉。一般情况下是每次迭代,只需一次 forward() 和一次 backward() ,前向运算forward() 和反向传播backward()是成对存在的,一般一次backward()也是够用的。但是不排除,由于自定义loss等的复杂性,需要一次forward(),多个不同loss的backward()来累积同一个网络的grad,来更新参数。于是,若在当前backward()后,不执行forward() 而可以执行另一个backward(),需要在当前backward()时,指定保留计算图,即backward(retain_graph)。

您可能感兴趣的与本文相关的镜像

PyTorch 2.8

PyTorch 2.8

PyTorch
Cuda

PyTorch 是一个开源的 Python 机器学习库,基于 Torch 库,底层由 C++ 实现,应用于人工智能领域,如计算机视觉和自然语言处理

--------------------------------------------------------------------------- RuntimeError Traceback (most recent call last) Cell In[11], line 8 1 """ 2 out.backward(torch.randn(1, 10)) 是对标量或张量进行反向传播时,传入一个 out 形状匹配的梯度张量(gradient tensor)作为权重。 3 简要回答: (...) 6 这个操作会沿着计算图自动计算每个参数的梯度并累加到 .grad 属性中。 7 """ ----> 8 out.backward(torch.randn(1,10)) File ~/anaconda3/envs/torch/lib/python3.8/site-packages/torch/_tensor.py:521, in Tensor.backward(self, gradient, retain_graph, create_graph, inputs) 511 if has_torch_function_unary(self): 512 return handle_torch_function( 513 Tensor.backward, 514 (self,), (...) 519 inputs=inputs, 520 ) --> 521 torch.autograd.backward( 522 self, gradient, retain_graph, create_graph, inputs=inputs 523 ) File ~/anaconda3/envs/torch/lib/python3.8/site-packages/torch/autograd/__init__.py:289, in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs) 284 retain_graph = create_graph 286 # The reason we repeat the same comment below is that 287 # some Python versions print out the first line of a multi-line function 288 # calls in the traceback and some print out the last line --> 289 _engine_run_backward( 290 tensors, 291 grad_tensors_, 292 retain_graph, 293 create_graph, 294 inputs, 295 allow_unreachable=True, 296 accumulate_grad=True, 297 ) File ~/anaconda3/envs/torch/lib/python3.8/site-packages/torch/autograd/graph.py:769, in _engine_run_backward(t_outputs, *args, **kwargs) 767 unregister_hooks = _register_logging_hooks_on_whole_graph(t_outputs) 768 try: --> 769 return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass 770 t_outputs, *args, **kwargs 771 ) # Calls into the C++ engine to run the backward pass 772 finally: 773 if attach_logging_hooks: RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.
最新发布
10-05
评论
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值