PyTorch 张量和模型的保存与加载

PyTorch支持将张量和模型保存为.pt、.pth、.pkl格式,张量可以通过torch.save()和torch.load()进行保存和加载,包括单个张量、多个张量以及张量关系的保存。模型的保存包括整个模型结构和仅保存state_dict,后者更推荐,但需要预先知道模型结构。加载时,可能需要将模型移到特定设备,如CPU或GPU。此外,保存模型时还会涉及优化器的状态,以便断电续训。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

PyTorch 张量和模型的保存与加载

PyTorch 中 Tensor 和模型可以保存为 .pt, .pth, .pkl 格式的文件,三者在格式上没有区别

张量的保存和加载

单个张量

x = torch.randint(100, (4, 4))
torch.save(x, './save_model/x.pt')  # 保存

y = torch.load('./save_model/x.pt')  # 加载
print(y)

多个张量

x = torch.randint(10, (2, 2))
y = torch.randint(100, (3, 3))

a = [x, y]
torch.save(a, './save_model/a.pt')
a_ = torch.load('./save_model/a.pt')

b = (x, y)
torch.save(b, './save_model/b.pt')
b_ = torch.load('./save_model/b.pt')

c = {'x':x, 'y':y}
torch.save(c, './save_model/c.pt')
c_ = torch.load('./save_model/c.pt')

torch.save() 会保留张量之间的关系。若只想保留张量的数据,不保留张量之间的关系,可以使用clone()来保存张量的数据克隆。

x = torch.arange(10)
y = x[:5]
torch.save([x, y.clone()], './save_model/x_and_y.pt')
x_, y_ = torch.load('./save_model/x_and_y.pt')

模型的保存和加载

保存模型

这种方式不仅保存、加载模型的数据,也包括模型的结构一并存储,存储的文件会较大,好处是加载时不需要提前知道模型的结构,解来即用。实际上这与上文提到的保存Tensor是一致的。

torch.save(model, save_path)  # 保存
model = torch.load(save_path)  # 加载
model.eval()  # 如果加载的模型要用于 inference,则需要使用 model.eval()

保存模型参数

这是一种较为推荐的保存方法,即只保存模型的参数,保存的模型文件会较小,而且比较灵活。但是当加载时,需要先实例化一个模型,然后通过加载将参数赋给这个模型的实例,也就是说加载之前使用者需要知道模型的结构

torch.save(model.state_dict(), save_path)  # 保存
model = Model()  # 实例化
model.load_state_dict(torch.load(save_path))  # 加载
model.eval()

比较重要的点是:

  • 保存模型时调用 state_dict() 获取模型的参数,而不保存结构
  • 加载模型时需要预先实例化一个对应的网络,比如model = Model(),这也就意味着,使用者需要预先有Model这个类,如果不知道这个网络的类定义或者结构,这种只保存参数的方法将无法使用
  • 加载模型使用 load_state_dict 方法,其参数不是文件路径,而是 torch.load(PATH)
  • 如果加载出来的模型用于验证,不要忘了使用 model.eval() 方法,它会丢弃 dropout、normalization 等层,因为这些层不能在inference 的时候使用,否则得到的推断结果不一致

保存与加载模型其他参数

有时我们不仅要保存模型,还要连带保存一些其他的信息。比如在训练过程中保存一些 checkpoint,往往除了模型,还要保存它的epoch、loss、optimizer等信息,以便于加载后对这些 checkpoint 继续训练等操作;或者再比如,有时候需要将多个模型一起打包保存等。这些其实也很简单,正如我们上文提到的,torch.save 可以保存dict、list、tuple等多种数据结构,所以一个字典可以很完美的解决这个问题,比如一个简单的例子:

# saving
torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss,
            ...
            }, PATH)

# loading
model = TheModelClass(*args, **kwargs)
optimizer = TheOptimizerClass(*args, **kwargs)

checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']

model.eval()
# - or -
model.train()

跨设备存储与加载

  • 从 CPU 保存,加载到 CPU

实际上,这就是默认的情况,我们上文提到的所有内容都没有关心设备的问题,因此也就适应于这种情况。

  • 从 CPU 保存,加载到 GPU
torch.save(net.state_dict(), save_path)  # 保存

device = torch.device('cuda: 0')
net = Net()
net.load_state_dict(torch.load(save_path, map_location=device))  # 通过 map_loacation 映射到 GPU
# or
net.load_state_dict(torch.load(save_path))  # 加载
net.to(device)  # 使用 .to() 发送到 GPU
  • 从 GPU 保存,加载到 CPU
torch.save(net.state_dict(), save_path)  # 保存
device = torch.device('cpu')
net = Net()  # 实例化
net.load_state_dict(torch.load(save_path, map_location=device))
  • 从 GPU 保存,加载到 GPU
torch.save(net.state_dict(), save_path)  # 保存
device = torch.device('cuda: 0')  # 指定设备
net = Net()  # 实例化
net.to(device)  # 使用 .to() 发送到 GPU

断电续训

  • 保存
torch.save({'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'epoch': epoch,
            'loss': loss,
            'acc': acc}, save_path)
  • 重新加载
model = Model()  # 实例化
optimizer = torch.optim.SGD(*args, **kwargs)
checkpoint = torch.load(save_path)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']
acc = checkpoint['acc']

model.train()  # 训练
# or
model.eval()  # 验证

在保存用于推理或者继续训练的常规检查点的时候,除了模型的 state_dict 之外,还必须保存其他参数。保存优化器的 state_dict 也非常重要,因为它包含了模型在训练时候优化器的缓存和参数。除此之外,还可以保存停止训练时 epoch 数,最新的模型损失,额外的 torch.nn.Embedding 层等。

要保存多个组件,则将它们放到一个字典中,然后使用 torch.save() 序列化这个字典。一般来说,使用 .tar 文件格式来保存这些检查点。

加载各个组件,首先初始化模型和优化器,然后使用 torch.load() 加载保存的字典,然后可以直接查询字典中的值来获取保存的组件。

同样,评估模型的时候一定不要忘了调用 model.eval()

参考资料

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值