一、禁止计算局部梯度
torch.autogard.no_grad: 禁用梯度计算的上下文管理器。
当确定不会调用Tensor.backward()计算梯度时,设置禁止计算梯度会减少内存消耗。如果需要计算梯度设置Tensor.requires_grad=True
两种禁用方法:
- 将不用计算梯度的变量放在with torch.no_grad()里
>>> x = torch.tensor([1.], requires_grad=True)
>>> with torch.no_grad():
... y = x * 2
>>> y.requires_grad
Out[12]:False
- 使用装饰器 @torch.no_gard()修饰的函数,在调用时不允许计算梯度
>>> @torch.no_grad()
... def doubler(x):
... return x * 2
>>> z = doubler(x)
>>> z.requires_grad
Out[13]:False
二、禁止后允许计算局部梯度
torch.autogard.enable_grad :允许计算梯度的上下文管理器
在一个no_grad上下文中使能梯度计算。在no_grad外部此上下文管理器无影响.
用法和上面类似:
- 使用with torch.enable_grad()允许计算梯度