pytorch cpu与gpu load时相互转化 torch.load(map_location=)

本文详细介绍了如何在PyTorch中将模型参数从一个设备加载到另一个设备,包括从CPU到GPU、从GPU到CPU以及在不同GPU间进行迁移的方法。同时,针对在特定环境下可能遇到的运行时错误提供了解决方案。
部署运行你感兴趣的模型镜像

pytorch将cpu训练好的模型参数load到gpu上,或者gpu->cpu上

假设我们只保存了模型的参数(model.state_dict())到文件名为modelparameters.pth, model = Net()

1. cpu -> cpu或者gpu -> gpu:

checkpoint = torch.load('modelparameters.pth')

model.load_state_dict(checkpoint)

2. cpu -> gpu 1

torch.load('modelparameters.pth', map_location=lambda storage, loc: storage.cuda(1))

3. gpu 1 -> gpu 0

torch.load('modelparameters.pth', map_location={'cuda:1':'cuda:0'})

4. gpu -> cpu

torch.load('modelparameters.pth', map_location=lambda storage, loc: storage)

或者:

如果出现这个报错

RuntimeError: Attempting to deserialize object on a CUDA device but torch.cuda.is_available() is False.

If you are running on a CPU-only machine, please use torch.load with map_location='cpu' to map your storages to the CPU.

torch.load(opt.model,map_location='cpu')

 

方便查阅

您可能感兴趣的与本文相关的镜像

PyTorch 2.7

PyTorch 2.7

PyTorch
Cuda

PyTorch 是一个开源的 Python 机器学习库,基于 Torch 库,底层由 C++ 实现,应用于人工智能领域,如计算机视觉和自然语言处理

在使用 PyTorch 加载模型检查点(checkpoint),若需要将模型参数映射到 CPU,可以通过 `torch.load` 函数中的 `map_location` 参数实现。该参数允许指定加载张量的目标设备,使用 `'cpu'` 字符串即可将模型参数加载到 CPU 上,适用于没有 GPU 或希望在 CPU 上运行模型的场景。 例如,当模型检查点保存的是通过 `torch.save(model.state_dict(), PATH)` 存储的状态字典,可以通过以下方式加载: ```python checkpoint = torch.load('model_checkpoint.pth', map_location='cpu') model.load_state_dict(checkpoint) ``` 此方法确保所有张量都被加载到 CPU 设备上,无需额外调整代码以适应 GPU 环境[^3]。 ### 加载模型检查点的设备映射 `map_location` 参数不仅接受字符串形式的设备标识,还支持 `torch.device` 对象。这意味着可以更精确地控制加载行为,例如: ```python device = torch.device('cpu') checkpoint = torch.load('model_checkpoint.pth', map_location=device) model.load_state_dict(checkpoint) ``` 这种灵活性使得开发者能够在不同设备上加载模型,而无需担心模型保存所使用的设备环境[^1]。 ### 多设备环境下的加载 在多设备环境下,如果模型原本是在 GPU 上训练并保存的,使用 `map_location='cpu'` 可以避免因设备不匹配导致的错误。这种方式特别适用于模型部署或测试阶段,尤其是在没有 GPU 的环境中运行模型。 ### 加载包含额外文件的检查点 如果模型检查点包含额外的文件(如配置文件、优化器状态等),可以使用 `_extra_files` 参数加载这些内容。例如: ```python extra_files = {'config.txt': ''} checkpoint = torch.load('model_checkpoint.pth', map_location='cpu', _extra_files=extra_files) print(extra_files['config.txt']) # 输出额外文件的内容 ``` 此功能允许在加载模型的同处理附加数据,为模型的配置和状态恢复提供便利。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值