pytorch中detach()函数和requires_grad_(False)解释。

部署运行你感兴趣的模型镜像

一、detach()requires_grad_(False) 解释

detach()requires_grad_(False) 函数是 PyTorch 中用于解除计算图连接和停止梯度计算的函数。

detach(): 将当前 tensor 从计算图中分离出来,返回一个新的 tensor,新 tensor 的 requires_grad 属性会被设为 False。也就是说,调用 detach() 后,将无法再通过这个 tensor 得到梯度信息,即使后续计算的结果与该 tensor 相关。
requires_grad_(False): 在原地修改 tensor 的 requires_grad 属性为 False,这个 tensor 之前的梯度信息将会被清除。
这两个函数在实际使用中常常用于以下几个方面:

1、在使用 Tensor 进行运算时,可能会产生一些临时结果,这些结果并不需要用于梯度计算,因此需要将这些 tensor 分离出计算图,以减少计算开销和内存消耗。

2、当需要对某些 tensor 进行修改,但不希望这些修改对之前的梯度计算产生影响时,可以使用 detach() 函数,以避免不必要的计算。

3、在使用预训练模型进行微调时,需要冻结一部分参数,不参与梯度更新。此时可以将这些参数的 requires_grad 属性设为 False,避免对其进行不必要的梯度计算。
二、例子

import torch

# 定义需要进行梯度计算的 tensor
x = torch.randn(3, 3, requires_grad=True)
w = torch.randn(3, 3, requires_grad=True)

# 计算 x 和 w 的点积
y = torch.matmul(x, w)

# 对 y 进行一些操作,但不需要计算梯度
z = y.detach().sigmoid()

# 计算 z 的平均值,并反向传播梯度
loss = z.mean()
loss.backward()

# 冻结 w 的梯度,只更新 x 的梯度
w.requires_grad_(False)
x.grad.zero_()
x2 = torch.randn(3, 3, requires_grad=True)
y2 = torch.matmul(x2, w)
loss2 = y2.mean()
loss2.backward()

# 打印梯度信息
print("x 的梯度:", x.grad)
print("w 的梯度:", w.grad)
print("x2 的梯度:", x2.grad)

在这个示例中,我们首先定义了两个需要进行梯度计算的 tensor:x 和 w。然后计算了它们的点积 y,并对 y 进行了一个 sigmoid 操作,得到了 z。由于我们不希望对 z 进行梯度计算,所以使用 detach() 函数将其从计算图中分离出来。接着,我们计算了 z 的平均值 loss,并进行了反向传播。这样就得到了 x 和 w 的梯度。

然后,我们冻结了 w 的梯度,只更新了 x 的梯度。为了说明这一点,我们定义了另外一个 tensor x2,将 x2 和 w 做点积,得到 y2。然后计算了 y2 的平均值 loss2,并进行了反向传播。由于我们冻结了 w 的梯度,所以只有 x2 的梯度被更新,w 的梯度仍然是零。最后打印了三个 tensor 的梯度信息。

您可能感兴趣的与本文相关的镜像

PyTorch 2.8

PyTorch 2.8

PyTorch
Cuda

PyTorch 是一个开源的 Python 机器学习库,基于 Torch 库,底层由 C++ 实现,应用于人工智能领域,如计算机视觉和自然语言处理

这段代码展示了PyTorch中梯度跟踪的控制机制,特别是`torch.no_grad()`的作用。以下是代码的详细解释运行结果: ### 代码解释 1. **创建张量**: ```python x = torch.tensor(1.0, requires_grad=True) ``` - 创建一个标量张量`x`,值为`1.0`,并启用梯度跟踪(`requires_grad=True`)。 2. **计算`y1`**: ```python y1 = x ** 2 ``` - `y1 = x^2`,由于`x`启用了梯度跟踪,`y1`也会保留梯度信息(`requires_grad=True`)。 3. **在`no_grad`上下文中计算`y2`**: ```python with torch.no_grad(): y2 = x ** 3 ``` - `torch.no_grad()`会临时禁用梯度跟踪,因此`y2 = x^3`的计算不会被记录在计算图中,`y2`的`requires_grad`为`False`。 4. **计算`y3`**: ```python y3 = y1 + y2 ``` - `y3`是`y1``y2`的。由于`y1`需要梯度而`y2`不需要,PyTorch会保留`y3`的梯度跟踪(`requires_grad=True`),但`y2`的部分不会被计算梯度。 5. **打印结果**: ```python print(x.requires_grad) # True print(y1, y1.requires_grad) # True print(y2, y2.requires_grad) # False print(y3, y3.requires_grad) # True ``` ### 运行结果 ``` True tensor(1., grad_fn=<PowBackward0>) True tensor(1.) False tensor(2., grad_fn=<ThAddBackward>) True ``` - `x.requires_grad`为`True`,因为初始化时启用了梯度跟踪。 - `y1`的值为`1.0`(因为`1.0^2 = 1.0`),`requires_grad=True`。 - `y2`的值为`1.0`(因为`1.0^3 = 1.0`),但`requires_grad=False`(因为在`no_grad`上下文中计算)。 - `y3`的值为`2.0`(因为`1.0 + 1.0 = 2.0`),`requires_grad=True`(因为`y1`需要梯度)。 ### 补充说明 1. **`torch.no_grad()`的作用**: - 禁用梯度跟踪,减少内存使用并加速计算(常用于推理阶段)。 - 在`no_grad`上下文中创建的张量不会参与梯度计算。 2. **`y3`的梯度计算**: - 如果对`y3`调用`backward()`,只会计算`y1`对`x`的梯度(`dy1/dx = 2x`),而`y2`的部分会被忽略。 - 示例: ```python y3.backward() print(x.grad) # 输出 tensor(2.)(因为 dy3/dx = dy1/dx = 2x = 2.0) ``` 3. **手动控制梯度**: - 如果需要`y2`也参与梯度计算,可以去掉`no_grad`上下文: ```python y2 = x ** 3 # requires_grad=True y3 = y1 + y2 y3.backward() print(x.grad) # 输出 tensor(5.)(因为 dy3/dx = dy1/dx + dy2/dx = 2x + 3x^2 = 2.0 + 3.0 = 5.0) ``` ### 常见问题 1. `torch.no_grad()``detach()`的区别是什么? - `no_grad()`是上下文管理器,临时禁用梯度跟踪。 - `detach()`返回一个新张量,永久脱离计算图。 2. 如果对`y3`调用`backward()`,`y2`的梯度会如何? - 不会计算`y2`的梯度,因为`y2`是在`no_grad`上下文中计算的。 3. 如何检查张量是否参与梯度计算? - 使用`requires_grad`属性(如`y2.requires_grad`)。 4. 如果`y2`需要梯度但`y1`不需要,该如何实现? - 对`y1`使用`detach()`或`no_grad`,对`y2`正常计算。 5. `y3.grad_fn`的值是什么? - `y3.grad_fn`是`<ThAddBackward>`,表示`y3`是通过加法操作得到的。
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值