torch中.data和.detach()的区别

本文详细解释了PyTorch中为何不能在requires_grad=True的叶子张量上使用inplace操作,以及在求梯度过程中为何不能修改相关张量。同时,对比了.data和.detach()在释放计算历史和安全性上的差异,建议在需要时使用.detach()以确保正确计算梯度。
部署运行你感兴趣的模型镜像

1. 概述

pytorch 中, 有两种情况不能使用 inplace operation

  • 对于 requires_grad=True 的 叶子张量(leaf tensor) 不能使用 inplace operation
  • 对于在求梯度阶段需要用到的张量不能使用 inplace operation

第一种情况: requires_grad=True 的 leaf tensor
类似于x = torch.ones(2, 2, requires_grad=True)中的x是直接创建的,它没有grad_fn,称为叶子节点

import torch
x = torch.ones(1, requires_grad=True)
print("x.data:", x.data)
print("x.data.requires_grad:", x.data.requires_grad)
y = 2*x
x *= 100
y.backward()
print(x)
print(x.grad)

输出:

x.data: tensor([1.])
x.data.requires_grad: False
Traceback (most recent call last):
  File "D:/Desktop/TEMP/TorchLearning/main.py", line 6, in <module>
    x *= 100
RuntimeError: a leaf Variable that requires grad is being used in an in-place operation.

改正:

x.data *= 100

输出:

x.data: tensor([1.])
x.data.requires_grad: False
tensor([100.], requires_grad=True)
tensor([2.])

第二种情况: 求梯度阶段需要用到的张量

import torch
x = torch.FloatTensor([[1., 2.]])
w1 = torch.FloatTensor([[2.], [1.]])
w2 = torch.FloatTensor([3.])
w1.requires_grad = True
w2.requires_grad = True

d = torch.matmul(x, w1)
f = torch.matmul(d, w2)
d[:] = 1 # 因为这句, 代码报错了 RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation

f.backward()


以上出错的原因在于 ∂ f ∂ w 2 = g ( d ) \frac{\partial f}{\partial w^{2}}=g(d) w2f=g(d),即 f f f ω 2 \omega_2 ω2 求导是关于 d d d 的函数,在计算 f f f 的时候, d d d 是等于某个值的, f f f ω 2 \omega_2 ω2 的导数和该值相关,但计算完 f f f 后, d d d 值变了,会导致求导错误。

改正:

import torch
x = torch.FloatTensor([[1., 2.]])
w1 = torch.FloatTensor([[2.], [1.]])
w2 = torch.FloatTensor([3.])
w1.requires_grad = True
w2.requires_grad = True

d = torch.matmul(x, w1)
d[:] = 1   # 稍微调换一下位置, 就没有问题了
f = torch.matmul(d, w2)
f.backward()

2. .data 与 .detach()的区别

相同点

  • 都和 x 共享同一块数据
  • 都和 x 的计算历史无关
  • requires_grad = False

不同点
y=x.data 在某些情况下不安全, 某些情况, 指的就是 上述 inplace operation 的第二种情况

import torch
x = torch.FloatTensor([[1., 2.]])
w1 = torch.FloatTensor([[2.], [1.]])
w2 = torch.FloatTensor([3.])
w1.requires_grad = True
w2.requires_grad = True

d = torch.matmul(x, w1)
d_ = d.data

f = torch.matmul(d, w2)
d_[:] = 1

f.backward()

# 这段代码没有报错, 但是计算上的确错了
# 如果 打印 w2.grad 结果看一下的话, 得到是1, 但是正确的结果应该是4 

上述代码应该报错, 因为: d_ 和 d 共享同一块数据,改 d_ 就相当于 改 d 了;但是, 代码并没有报错 , 但是计算上的确错了

所以, release note 中指出, 如果想要 detach 的效果的话, 还是 detach() 安全一些.

import torch
x = torch.FloatTensor([[1., 2.]])
w1 = torch.FloatTensor([[2.], [1.]])
w2 = torch.FloatTensor([3.])
w1.requires_grad = True
w2.requires_grad = True

d = torch.matmul(x, w1)

d_ = d.detach() # 换成 .detach(), 就可以看到 程序报错了...

f = torch.matmul(d, w2)
d_[:] = 1
f.backward()

总结: .data和.detach() 功能相同,但是推荐使用detach(),因为其更安全

参考

关于 pytorch inplace operation, 需要知道的几件事

您可能感兴趣的与本文相关的镜像

ACE-Step

ACE-Step

音乐合成
ACE-Step

ACE-Step是由中国团队阶跃星辰(StepFun)与ACE Studio联手打造的开源音乐生成模型。 它拥有3.5B参数量,支持快速高质量生成、强可控性和易于拓展的特点。 最厉害的是,它可以生成多种语言的歌曲,包括但不限于中文、英文、日文等19种语言

评论
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值