ResNet18 with pretrained weights: training fails with AssertionError: parameter contains nan
Original code:
model = resnet18(weights=models.ResNet18_Weights.DEFAULT)  # load the ImageNet-pretrained model
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, args.num_classes)
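This error usually points to numerical instability during training, most often exploding gradients (vanishing gradients are a related problem, but they rarely produce NaN on their own).
Modified code, using Kaiming initialization: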
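To narrow down where the NaN first appears, it can help to check the loss and the gradient norm on every step. The following is a minimal sketch, not part of the original code: `train_step`, `max_grad_norm`, and the toy model are hypothetical names chosen for illustration, and gradient clipping is a swapped-in mitigation rather than the fix described in this post.

```python
import torch
import torch.nn as nn


def train_step(model, batch, targets, optimizer, criterion, max_grad_norm=1.0):
    """Run one optimization step; skip the update if loss or gradients are non-finite.

    Returns the loss as a float, or None when the step was skipped.
    """
    optimizer.zero_grad()
    loss = criterion(model(batch), targets)
    if not torch.isfinite(loss):
        return None  # stepping on a NaN/inf loss would poison every parameter
    loss.backward()
    # clip_grad_norm_ rescales gradients in place and returns the pre-clipping total norm
    total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
    if not torch.isfinite(total_norm):
        return None
    optimizer.step()
    return loss.item()


# Tiny stand-in model so the sketch runs without a dataset (hypothetical setup)
torch.manual_seed(0)
toy_model = nn.Linear(4, 2)
toy_optimizer = torch.optim.SGD(toy_model.parameters(), lr=0.1)
toy_loss = train_step(toy_model, torch.randn(8, 4), torch.randint(0, 2, (8,)),
                      toy_optimizer, nn.CrossEntropyLoss())
```

A step that returns None marks the first batch on which the loss or the gradients became non-finite, which is usually a better place to start debugging than the later assertion on the parameters.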
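import torch
import torch.nn as nn
from torchvision import models
from torchvision.models import resnet18

# Use the new weights API
model = resnet18(weights=models.ResNet18_Weights.DEFAULT)  # load the ImageNet-pretrained model
# Replace the first convolution to suit CIFAR-10's 32x32 resolution
model.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
# Kaiming initialization (note: this loop overwrites the pretrained weights in every layer)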
for m in model.modules():
    if isinstance(m, nn.Conv2d):
        nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
    elif isinstance(m, nn.BatchNorm2d):
        nn.init.constant_(m.weight, 1)
        nn.init.constant_(m.bias, 0)
    elif isinstance(m, nn.Linear):
        nn.init.normal_(m.weight, 0, 0.01)
        nn.init.constant_(m.bias, 0)
# Replace the final fully connected layer
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, args.num_classes)
# Move the model to the target device
model = model.to(device)
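# Check that no parameter contains NaN
for name, param in model.named_parameters():
    if torch.isnan(param).any():
        print(f"Warning: NaN detected in {name}")
        # Replace only the NaN entries with 0 (zero_() would wipe the whole tensor)
        param.data[torch.isnan(param.data)] = 0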