import torch
x = [
torch.tensor(i, dtype=torch.float, requires_grad=True)
for i in range(1, 10)
]
with torch.no_grad():
y = map(lambda x: x * 2, x)
print(sum(y)) # tensor(90., grad_fn=<AddBackward0>)
虽然 y 是在 torch.no_grad 环境下创建的,但是由于 map 具有惰性求值的特性,所以 y 实际是在 sum(y) 被调用时才被计算出来。这时已经退出 torch.no_grad 环境,因此 sum(y) 有梯度。
在 PyTorch 中,torch.no_grad 主要用于禁止计算图的构建,以减少内存占用和计算开销。另一方面,Python 内置函数 map 具有惰性求值的特性,这可能会导致 torch.no_grad 失效,从而影响计算结果。
torch.no_grad
torch.no_grad 是 PyTorch 提供的一个上下文管理器,它的作用是在代码执行时禁止计算梯度。这样可以减少不必要的计算,提高执行效率。例如
import torch
torch.manual_seed(0)
x = torch.randn(3, requires_grad=True)
with torch.no_grad():
y = x * 2
print(y.requires_grad) # False
在 torch.no_grad() 作用域下计算 y = x * 2,由于梯度被禁用,y.requires_grad 为 False,即 PyTorch 不会为 y 构建计算图。
Python map 的惰性求值
在 Python 中,map 是一个惰性计算的函数,它返回一个迭代器,并不会立即执行计算。例如
def square(x: int) -> int:
print(f"Computing square of {x}")
return x * x
numbers = [1, 2, 3]
squares = map(square, numbers)
print(squares)
print("-" * 10)
print(list(squares))
执行上述代码,将会输出
<map object at 0x7f10fdf17730>
----------
Computing square of 1
Computing square of 2
Computing square of 3
[1, 4, 9]
可以看到,squares 并不是一个列表,而是一个 map 类型的对象。此外,Computing square of ... 的在 ---------- 之后被打印,证明 square 函数在 map(square, numbers) 中没有被执行,而是在 list(squares) 中才被执行的。这是因为 map 对象只有在被遍历(例如 list 或 sum)时才真正执行。
map 惰性求值导致 torch.no_grad 失效
结合 torch.no_grad 和 map,可能会出现意想不到的行为。例如
import torch
x = [
torch.tensor(i, dtype=torch.float, requires_grad=True)
for i in range(1, 10)
]
with torch.no_grad():
y = map(lambda x: x * 2, x)
print(sum(y).requires_grad) # True
为什么 y 仍然有梯度?
因为 map(lambda x: x * 2, x) 只创建了一个 map 对象,并没有真正执行 x * 2 计算。torch.no_grad() 作用域结束后,计算 sum(y) 时,x * 2 计算才真正发生。这时 torch.no_grad() 已经失效,因此 PyTorch 仍然会追踪梯度。
解决方案
如果希望 torch.no_grad 正确生效,可以强制 map 立即执行
with torch.no_grad():
y = list(map(lambda x: x * 2, x))
print(sum(y).requires_grad) # False
这样 y 在 torch.no_grad 作用域内就完成计算,确保 sum(y) 不会被追踪梯度。
6325

被折叠的 条评论
为什么被折叠?



