ResNet18 加载预训练权重,训练过程出现 AssertionError: parameter contains nan

ResNet18 加载预训练权重,训练过程出现 AssertionError: parameter contains nan
原代码:

model = resnet18(weights=models.ResNet18_Weights.DEFAULT)  # 先加载未预训练的模型
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, args.num_classes)

这个错误通常是由于梯度爆炸或消失导致的。
使用凯明初始化,修改后的代码:

# 使用新的权重加载方式
    model = resnet18(weights=models.ResNet18_Weights.DEFAULT)  # 先加载未预训练的模型
    
    # 修改第一层卷积以适应CIFAR-10的32x32分辨率
    model.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
    
    # 使用Kaiming初始化
    for m in model.modules():
        if isinstance(m, nn.Conv2d):
            nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
        elif isinstance(m, nn.BatchNorm2d):
            nn.init.constant_(m.weight, 1)
            nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.Linear):
            nn.init.normal_(m.weight, 0, 0.01)
            nn.init.constant_(m.bias, 0)
    
    # 修改最后的全连接层
    num_ftrs = model.fc.in_features
    model.fc = nn.Linear(num_ftrs, args.num_classes)
    
    # 将模型移动到指定设备
    model = model.to(device)
    
    # 验证模型参数没有NaN值
    for name, param in model.named_parameters():
        if torch.isnan(param).any():
            print(f"Warning: NaN detected in {name}")
            param.data.zero_()  # 将NaN值替换为0
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值