pytorch中copy_()、detach()、data()和clone()操作区别小结

本文详细介绍了PyTorch中clone、copy_、detach和data四个操作的区别:clone创建新tensor与源共享数据但梯度累加;copy_先构建目标再复制源,两者内存不同;detach创建独立于计算图的新tensor;data获取不保存梯度的相同数据。理解这些操作有助于高效地进行梯度计算和内存管理。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

1. clone

b = a.clone()

创建一个tensor与源tensor有相同的shape,dtype和device,不共享内存地址,但新tensor(b)的梯度会叠加在源tensor(a)上。需要注意的是,b = a.clone()之后,b并非叶子节点,所以不可以访问它的梯度。

import torch

a = torch.tensor([1.,2.,3.],requires_grad=True)
b = a.clone()


print('===========================不共享地址=========================')
print(type(a), a.data_ptr())
print(type(b), b.data_ptr())
print('===========================clone后分别输出=========================')
print('a: ', a) # a:  tensor([1., 2., 3.], requires_grad=True)
print('b: ', b) #b:  tensor([1., 2., 3.], grad_fn=<CloneBackward0>)

c = a ** 2
d = b ** 3
print('===========================反向传播=========================')
c.sum().backward() # 2* a
print('a.grad: ', a.grad)  #a.grad:  tensor([2., 4., 6.])

d.sum().backward() # 3b**2
print('a.grad: ', a.grad) #a.grad:  tensor([ 5., 16., 33.]) ,会将b梯度累加上去
#print('b.grad: ', b.grad) # b.grad:  None , 已经不属于计算图的叶子,不可以访问b.grad

输出:

===========================不共享地址=========================
<class 'torch.Tensor'> 93899916787840
<class 'torch.Tensor'> 93899917014528
===========================clone后分别输出=========================
a:  tensor([1., 2., 3.], requires_grad=True)
b:  tensor([1., 2., 3.], grad_fn=<CloneBackward0>)
===========================反向传播=========================
a.grad:  tensor([2., 4., 6.])
a.grad:  tensor([ 5., 16., 33.])

2. copy_

b = torch.empty_like(a).copy_(a)

copy_()函数是需要一个目标tensor,也就是说需要先构建b,然后将a拷贝给b,而clone操作则不需要。

copy_()函数完成与clone()函数 类似的功能,但也存在区别。调用copy_()的对象是目标tensor,参数是复制操作from的tensor,最后会返回目标tensor;而clone()的调用对象为源tensor,返回一个新tensor。当然clone()函数也可以采用torch.clone()调用,将源tensor作为参数。

import torch

a = torch.tensor([1., 2., 3.],requires_grad=True)
b = torch.empty_like(a).copy_(a)

print('====================copy_内存不一样======================')
print(a.data_ptr())
print(b.data_ptr()) 
print('====================copy_打印======================')
print(a) 
print(b) 

c = a ** 2
d = b ** 3
print('===================c反向传播=======================')
c.sum().backward()
print(a.grad) # tensor([2., 2., 2.])
print('===================d反向传播=======================')
d.sum().backward()
print(a.grad) # 源tensor梯度累加了
#print(b.grad) # None 

输出:

====================copy_内存不一样======================
94358408685568
94358463065088
====================copy_打印======================
tensor([1., 2., 3.], requires_grad=True)
tensor([1., 2., 3.], grad_fn=<CopyBackwards>)
===================c反向传播=======================
tensor([2., 4., 6.])
===================d反向传播=======================
tensor([ 5., 16., 33.])

3. detach

detach()函数返回与调用对象tensor相关的一个tensor,此新tensor与源tensor共享数据内存(那么tensor的数据必然是相同的),但其requires_grad为False,并且不包含源tensor的计算图信息

import torch

a = torch.tensor([1., 2., 3.],requires_grad=True)
b = a.detach()
print('=========================共享内存==============================')
print(a.data_ptr()) 
print(b.data_ptr()) 
print('=========================原值与detach==============================')
print(a) 
print(b) 

c = a * 2
d = b * 3 #不可以反向传播
print('=========================原值反向传播==============================')
c.sum().backward()
print(a.grad) 
print('=========================detach不可以反向传播==============================')
# d.sum().backward()

输出:

=========================共享内存==============================
94503766034432
94503766034432
=========================原值与detach==============================
tensor([1., 2., 3.], requires_grad=True)
tensor([1., 2., 3.])
=========================原值反向传播==============================
tensor([2., 2., 2.])
=========================detach不可以反向传播==============================

由于b已经从计算图脱离出来,pytorch自然也不跟踪其后续计算过程了。如果想要让b重新加入计算图,只需要b.requires_grad_()

pytorch可以继续跟踪b的计算,但梯度不会从b流回a,梯度被截断。但由于b与a共享内存,a与b的值会一直相等

4. data

data方法是得到一个tensor的数据信息,其返回的信息与上面提到的detach()返回的信息是相同的,也具有 内存相同,不保存梯度 信息的特点。但是data有时候不安全,因为它们共享内存,如果改变一个则另一个也跟着改变,而使用detach时候使用反向传播会报错

import torch
import pdb
x = torch.FloatTensor([[1., 2.]]) #默认x.requires_grad == False,只有float类型可以反向传播
w1 = torch.FloatTensor([[2.], [1.]])
w2 = torch.FloatTensor([3.])

w1.requires_grad = True
w2.requires_grad = True

d = torch.matmul(x, w1) #相乘后,d的requires_grad = True(相加操作也是True)

d_ = d.data # d和d_会共享内存,d_的requires_grad  = False
# d_ = d.detach() #d和d_也会共享内存,但是不能反向传播

f = torch.matmul(d, w2)
d_[:] = 1 #d_修改了值,所以d的值也跟着改变

f.backward() #使用data会获取错误的值,使用detach则报错

参考:

https://zhuanlan.zhihu.com/p/38475183

<think>嗯,用户想了解PyTorch中的detach函数的作用使用场景。我需要先回忆一下相关的知识点。根据之前看到的引用资料,detach函数主要是用来从计算图中分离张量,停止梯度传播。比如引用3提到,detach()会创建一个新张量,不再参与梯度计算。引用4也说明分离后的张量不再有梯度信息。 接下来,用户可能想知道具体的使用场景。根据引用2,detach在控制计算图方面很关键,比如在生成对抗网络(GAN)中,防止生成器更新时影响判别器。或者是在模型评估时,不需要计算梯度的情况下使用detach节省资源。引用3还提到了节省计算资源的例子,比如分离不参与梯度更新的张量,减小计算图规模。 需要解释清楚detach的基本作用,比如停止梯度追踪、节省资源,然后结合代码示例说明。比如在代码中,使用detach后的张量requires_grad变为False,反向传播时不会计算它的梯度。引用3中的示例代码正好可以说明这一点,x的梯度仍然存在,但z的梯度被分离了。 还要注意用户可能遇到的常见问题,比如什么时候该用detach,误用会不会导致梯度错误。比如在需要固定部分参数,或者在将张量转换为numpy时,必须用detach,否则可能影响计算图或报错。 另外,生成相关的问题需要考虑用户可能进一步想了解的内容,比如detachwith torch.no_grad()区别,或者如何在实际项目中应用detach,或者常见错误有哪些。这样用户可以有更全面的理解。 总结一下,回答的结构应该是先介绍作用,再分点详细说明,然后给出代码示例,最后注意事项相关问题。确保使用正确的中文术语,并正确应用LaTeX格式,比如张量用$...$,独立公式用$$...$$。同时,引用资料中的例子要自然地融入回答中,并在适当的位置加上引用标识。</think>### PyTorch中`detach`函数的作用与使用场景 #### 一、`detach`函数的作用 `detach()`是PyTorch中用于**从计算图中分离张量**的核心方法,其主要功能包括: 1. **停止梯度追踪** 通过`detach()`生成的张量与原始计算图完全分离,后续操作不再记录梯度。例如: ```python x = torch.tensor([2.0], requires_grad=True) y = x * 3 z = y.detach() # z脱离计算图,不再参与梯度计算 ``` 反向传播时,`z`的梯度不会被计算,而`y``x`仍保留梯度[^3][^4]。 2. **节省计算资源** 分离不需要梯度更新的张量可以缩小计算图的规模,降低内存计算消耗[^3]。 3. **保护数据完整性** 在需要将张量转换为NumPy数组或用于可视化时,必须使用`detach()`避免因梯度追踪导致的操作错误[^2]。 --- #### 二、典型使用场景 1. **固定模型参数** 在迁移学习中,若需冻结预训练模型的部分参数: ```python for param in model.layer.parameters(): param.detach_() # 直接修改参数张量,使其脱离计算图 ``` 2. **生成对抗网络(GAN)** 训练生成器时,需防止梯度传播到判别器: ```python real_output = discriminator(real_data) fake_data = generator(noise).detach() # 分离生成器输出,避免影响判别器梯度 fake_output = discriminator(fake_data) ``` 3. **模型推理与评估** 在测试阶段禁用梯度计算以提高效率: ```python with torch.no_grad(): output = model(input) # 等价于 output = model(input).detach() ``` 4. **梯度裁剪** 在反向传播前分离部分张量以防止梯度爆炸: ```python loss = criterion(output, target) loss.backward(retain_graph=True) clipped_grad = torch.clamp(model.weight.grad.detach(), -1, 1) # 分离后裁剪 model.weight.grad.copy_(clipped_grad) ``` --- #### 三、代码示例 ```python import torch # 原始张量 x = torch.tensor([2.0], requires_grad=True) y = x**2 + 3*x z = y.detach() # z脱离计算图 # 反向传播 y.backward() print(x.grad) # 输出: tensor([7.]) (因 dy/dx = 2x + 3 = 7) # 对z的操作不影响原始计算图 z += 5 print(z.requires_grad) # 输出: False ``` --- #### 四、注意事项 1. **内存共享** `detach()`生成的新张量与原始张量共享内存,修改其一会影响另一个。若需完全独立,需调用`.detach().clone()`[^4]。 2. **与`with torch.no_grad()`的区别** `detach()`仅针对单个张量,而`with torch.no_grad()`作用于代码块内的所有操作。 3. **反向传播中断** 若计算图中所有张量均被分离,梯度将无法回传,导致训练失败[^2]。 ---
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

非晚非晚

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值