pytorch in-place运算和反向传播报错之间的关系

本文探讨了PyTorch中叶子节点与非叶子节点的概念,重点分析了非叶子节点进行In-place运算与反向传播报错之间的联系,并通过具体示例展示了如何避免因In-place运算引发的错误。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

前言

pytorch中,变量分为叶子节点和非叶子节点,叶子节点没有grad_fn属性,而非叶子节点是由计算产生的,包含有grad_fn属性,对于叶子节点一般是不允许进行in-place运算的,in-place运算会导致叶子节点变为非叶子节点,这时一般会报错,关于叶子节点的in-place运算问题暂且先不讨论。主要讨论对于非叶子节点进行in-place运算与报错之间的联系。

实例

在pytorch中采取反向传播算法来计算梯度,并且应用求导的链式法则,在进行计算的时候,隐性的构建一个计算图,关于非叶子节点进行in-place运算的时候,会将变量的_version属性加1,当在反向传播的时候遇到同一个变量但是_version属性不一致的时候就会报错,如果在求梯度的时候只用到了一次或者没用过,那么即使进行了in-place运算也不会报错,依然可以求出梯度。
举例如下:

import torch
import torch.nn as nn
import math
import torch.nn.functional as F

if __name__ == '__main__':
    Q = torch.randn(2, 2, 3, requires_grad=True)
    V = torch.randn(2, 3, 3, requires_grad=True)
    V1 = V + 1
    a = torch.bmm(Q, V1.permute(0, 2, 1))
    a /= math.sqrt(3)
    V1[:, 0, 1] = 0
    a[:, 1, 2:] = -1e6
    a = F.softmax(a, dim=-1)
    res = torch.bmm(a, V1)
    l = res.sum().backward()
    print(V.grad)

上述代码,在backward()的时候,会因为进行了in-place运算产生报错,原因在于,对torch.bmm(a,V1)操作求梯度的时候,对a的梯度是含有V1变量的,在torch.bmm(Q,V1)求对Q的梯度的时候,也是含有V1变量的,这两个变量的版本不一致,在这两个表达式之间进行了in-place运算,所以导致了报错,但是如果不对Q求梯度,只对V求梯度,那么就不会出现第二次V1变量,也就不会产生in-place运算的错误。

import torch
import torch.nn as nn
import math
import torch.nn.functional as F

if __name__ == '__main__':
    Q = torch.randn(2, 2, 3, requires_grad=False)
    V = torch.randn(2, 3, 3, requires_grad=True)
    V1 = V + 1
    a = torch.bmm(Q, V1.permute(0, 2, 1))
    a /= math.sqrt(3)
    V1[:, 0, 1] = 0
    a[:, 1, 2:] = -1e6
    a = F.softmax(a, dim=-1)
    res = torch.bmm(a, V1)
    l = res.sum().backward()
    print(V.grad)

把Q的requires_grad设为False,就可以成功的求出梯度了。

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值