Pytorch多卡训练后的模型在使用过程中的一些坑

本文探讨在PyTorch中利用多GPU环境进行模型训练,如何通过DataParallel实现并行计算loss的求平均,以及在模型保存和加载时的注意事项。特别介绍了针对DataParallel模型的保存策略。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

if torch.cuda.device_count() > 1:
     print("Let's use", torch.cuda.device_count(), "GPUs!")
     model = torch.nn.DataParallel(model)

多GPU环境计算loss

解决方法:
将多卡得到的loss进行mean,求平均:

loss_avg.backward()  -----> loss_avg.mean().backward()

多GPU环境下对模型的保存:

#万能的保存方法,如果你的预测函数不会依赖你的模型类定义。
if isinstance(model, torch.nn.DataParallel):
torch.save(model.module.state_dict(), config.save_path)
#依赖你的模型定义,这么保存只保存了你模型的参数,模型的结构没有保存,所以尽量用上面的保存方法。
if isinstance(model, torch.nn.DataParallel):
torch.save(model.state_dict(), config.save_path)

多GPU环境下对模型的加载去预测:

model = torch.load(‘path/to/model’)
if isinstance(model,torch.nn.DataParallel):
model = model.module
#下面就可以正常使用了
model.eval()

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值