detach detach_ pytorch

本文详细介绍了PyTorch中detach和detach_的功能及区别。detach返回一个从计算图分离的新Variable,而detach_则直接在原地修改Variable,将其变为叶子节点。这两种方法常用于控制网络中哪些部分参与梯度计算。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

pytorch中的detach和detach_
pytorch 的 Variable 对象中有两个方法,detach和 detach_ :

detach
官方文档中,对这个方法是这么介绍的。

返回一个新的从当前图中分离的 Variable。
返回的 Variable 永远不会需要梯度
如果 被 detach 的Variable volatile=True, 那么 detach 出来的 volatile 也为 True
还有一个注意事项,即:返回的 Variable 和 被 detach 的Variable 指向同一个 tensor

import torch
from torch.nn import init
from torch.autograd import Variable
t1 = torch.FloatTensor([1., 2.])
v1 = Variable(t1)
t2 = torch.FloatTensor([2., 3.])
v2 = Variable(t2)
v3 = v1 + v2
v3_detached = v3.detach()
v3_detached.data.add_(t1) # 修改了 v3_detached Variable中 tensor 的值
print(v3, v3_detached)    # v3 中tensor 的值也会改变

detach的源码:

1 # detach 的源码
2 def detach(self):
3     result = NoGrad()(self)  # this is needed, because it merges version counters
4     result._grad_fn = None
5     return result

detach_
官网给的解释是:将 Variable 从创建它的 graph 中分离,把它作为叶子节点。

从源码中也可以看出这一点

将 Variable 的grad_fn 设置为 None,这样,BP 的时候,到这个 Variable 就找不到 它的 grad_fn,所以就不会再往后BP了。
将 requires_grad 设置为 False。这个感觉大可不必,但是既然源码中这么写了,如果有需要梯度的话可以再手动 将 requires_grad 设置为 true

# detach_ 的源码
def detach_(self):
    """Detaches the Variable from the graph that created it, making it a
    leaf.
    """
    self._grad_fn = None
    self.requires_grad = False

能用来干啥
可以对部分网络求梯度。

如果我们有两个网络 , 两个关系是这样的 现在我们想用 来为B网络的参数来求梯度,但是又不想求A网络参数的梯度。我们可以这样:

# y=A(x), z=B(y) 求B中参数的梯度,不求A中参数的梯度
# 第一种方法
y = A(x)
z = B(y.detach())
z.backward()
  
# 第二种方法
y = A(x)
y.detach_()
z = B(y)
z.backward()

在这种情况下,detach 和 detach_ 都可以用。但是如果 你也想用 y

### PyTorch `detach` 与 `detach_` 的用法及区别 #### 1. **`detach()` 方法** `detach()` 是一种非原地操作 (out-of-place operation),它会返回一个新的张量,该张量与原始张量共享相同的底层数据存储,但新张量不再属于原有的计算图的一部分[^3]。这意味着新的张量不会记录任何对其的操作,也不会参与到梯度计算中。 以下是 `detach()` 的主要特点: - 不会影响原始张量的计算图。 - 新张量的 `requires_grad` 属性始终为 `False`,无论原始张量是否有梯度需求。 - 如果需要修改分离后的张量的数据而不影响原始张量的梯度流,则应使用此方法。 示例代码如下: ```python import torch x = torch.tensor([2.0, 3.0], requires_grad=True) y = x * 3 z = y.detach() # 创建了一个脱离计算图的新张量 z print("z requires_grad:", z.requires_grad) # 输出 False ``` --- #### 2. **`detach_()` 方法** `detach_()` 则是一种原地操作 (in-place operation)[^4],它会在原有张量的基础上直接将其从计算图中移除,并将自身的 `requires_grad` 属性设置为 `False`。由于它是原地操作,因此会对调用它的张量本身造成永久性的更改——即将其从计算图中完全剥离。 需要注意的是,一旦执行了 `detach_()`,原本指向同一块内存区域的所有变量都会受到影响,它们都将失去对原来计算图的连接关系。 具体表现可以通过下面的例子来说明: ```python import torch x = torch.tensor([2.0, 3.0], requires_grad=True) y = x * 3 y.detach_() # 将 y 自身从计算图中移除 print("y after detach_: ", y.requires_grad) # 输出 False try: y.sum().backward() except RuntimeError as e: print(e) # 报错提示:element 0 of tensors does not require grad and does not have a grad_fn ``` 在这里可以看到,在应用 `detach_()` 后再尝试进行反向传播时会发生错误,因为此时已经没有任何路径能够追溯回输入端以完成链式法则所需的导数累积过程。 --- #### 3. **两者的主要差异总结** | 特性 | `detach()` | `detach_()` | |-------------------------|-------------------------------------|-----------------------------------| | 是否改变原对象 | 否 | 是 | | 返回值 | 新的对象 | None | | 计算图状态 | 复制一份独立于现有计算图之外的新实例 | 当前对象被标记为不需要跟踪历史 | | 性能考虑 | 更安全(无副作用),适合复杂场景 | 效率更高但在简单调整中有风险 | 当决定采用哪种方式取决于实际应用场景以及个人偏好等因素综合考量之后才能做出最佳选择。 --- ###
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值