一、背景
pytorch是深度学习的重要建模工具,也几乎是人工智能工程师必须掌握的一大法宝。在建模的过程中,pytorch提供了很多高效的函数以及方法便于我们对数据、模型进行处理。为了避免重复造轮子,我们来盘点下pytorch常用的函数及方法。
二、常见函数及方法
1、张量操作
torch.tensor(data)
:创建一个新的张量。tensor.numpy()
:将PyTorch张量转换为NumPy数组。tensor.detach()
:返回一个新的张量,与当前计算图断开连接,不会在梯度传播中跟踪。例如,在模型评估阶段,我们需要对模型的输出进行一些操作,但不希望这些操作影响模型的梯度时,使用detach()
可以确保评估过程中的计算不会影响训练过程。此外,当我们需要将张量从一个设备(如GPU)移动到另一个设备(如CPU)时,detach()
可以避免不必要的梯度计算。
import torch
# 创建一个需要梯度的张量
x = torch.randn(2, 2, requires_grad=True)
# 执行一些操作
y = x * x
# 将y从计算图中分离
z = y.detach()
z.requires_grad_()
# 对z进行操作不会影响x的梯度
z.sum().backward() # 这不会在x上产生梯度
# 检查x的梯度
print(x.grad) # None,因为没有对x进行需要梯度的计算
tensor.requires_grad_()
:设置张量,使其在计算图中跟踪梯度。
import torch
# 创建一个张量,默认不跟踪梯度
x = torch.randn(2, 2)
print(x.requires_grad) # 输出: False
# 开启梯度跟踪
x.requires_grad_() # 现在x将跟踪梯度
print(x.requires_grad) # 输出: True
tensor.grad
:获取或设置张量的梯度。torch.no_grad()
ÿ