目录
方法 1:设置 weights_only=False(不安全)
问题描述
最近在学习b站up:小土堆的PyTorch教程。
在“完整的模型验证套路”一节中遇到了WeightsUnpickler error: Unsupported global: GLOBAL __main__.Model was not an allowed global by default的问题:
通过解读错误信息得知,这与PyTorch 中的 torch.load
函数有关,具体是在尝试加载模型检查点文件(model_0.pth
)时发生的。错误信息表明,从 PyTorch 2.6 开始,torch.load 的 weights_only 参数默认设置为 True,这是为了防止加载不受信任的文件时执行任意代码。这个设置限制了可以加载的对象类型,只允许加载张量、字典和张量列表等安全对象。
错误原因
-
weights_only=True
:这个设置是为了安全考虑,防止加载可能包含恶意代码的文件。如果你加载的文件中包含自定义类(例如__main__.Model
),默认情况下会被拒绝。 -
不支持的自定义类:错误信息中提到
GLOBAL __main__.Model
是一个不被允许的全局对象,因为weights_only=True
不允许加载自定义类。
解决方法
有三种解决方法:
方法 1:设置 weights_only=False(不安全)
如果信任模型文件(模型)的来源,可以通过将 weights_only
参数设置为 False
来加载模型。这会允许加载自定义类,但需要注意潜在的安全风险。
import torch
# 设置 weights_only=False
model = torch.load('model_0.pth', weights_only=False)
注意:只有在确定文件来源可信的情况下才使用此方法,否则可能会导致任意代码执行。
方法 2:使用 add_safe_globals(安全)
如果你希望继续使用 weights_only=True
,可以通过将自定义类添加到安全全局对象列表中来解决。具体步骤如下:
torch.serialization.add_safe_globals(
[Tudui,Sequential,Conv2d,MaxPool2d,Flatten,Linear])
model = torch.load('model_0.pth')
PS:我之所以只写类名是因为在导入时是这样写的:
from torch.nn import Module,Sequential,Conv2d,MaxPool2d,Flatten,Linear
如果你导入用的是import torch.nn,那么上边的代码要换为:
torch.serialization.add_safe_globals(
[Tudui,nn.Sequential,nn.Conv2d,nn.MaxPool2d,nn.Flatten,nn.Linear])
model = torch.load('model_0.pth')
如果你不确定add_safe_globals
里需要添加哪些类,可以通过以下步骤逐步调试:
import torch
try:
model = torch.load('model_0.pth', weights_only=True)
except pickle.UnpicklingError as e:
print(f"Error: {e}")
运行代码后,错误信息会提示哪些类需要被允许。例如,如果提示 Please use `torch.serialization.add_safe_globals([Linear])`:
那么就需要将Linear 添加到列表中。
方法3:重新训练模型(推荐)
在训练模型的代码中,使用如下方式训练并保存模型:
torch.save(model.state_dict(), f"model_{i}.pth")
然后再次训练模型,并通过如下方式加载模型(替换原来的model = torch.load('model_0.pth')):
model = Model()
model.load_state_dict(torch.load('model_0.pth'))
可能遇到的其他问题
RuntimeError: Input type (torch.FloatTensor) and weight type (torch.cuda.FloatTensor) should be the same or input should be a MKLDNN tensor and weight is a dense tensor
这个错误表明模型权重是存储在 GPU 上的(torch.cuda.FloatTensor
),而输入数据是在 CPU 上的(torch.FloatTensor
)。PyTorch 要求模型和输入数据必须在同一设备上(要么都在 CPU 上,要么都在 GPU 上)。
最简单的解决办法:用cuda(GPU)加载输入数据即可:
with torch.no_grad():
output = model(img.cuda())
或者也可以通过方法3来解决。