模型训练一开始好好的,前几轮也好好的,突然,loss出现了nan,我的模型准确率本来都很高了,啪地跌到了0.
排查1:可能是数据集有问题。经过更换100%没问题的数据集之后,nan问题仍然出现
排查2:通过调试,发现计算loss之前的模型输出全是nan,啊啊啊
网上有说lr过大导致参数nan,有除0 log0 等无效计算导致nan,也有说梯度爆炸导致更新模型参数有了nan。
我觉得我的lr没问题,也没有无效计算的情况,所以我猜测是梯度爆炸。我尝试了如下的代码,放在更新参数之前。意外之喜是 它能使得我的模型加速收敛,准确率噌噌噌提高!!!
torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0)
but,nan依旧!!
排查3:我只能用下面的代码+断点排查nan在哪
if torch.isnan(grasp_cos_out).any() or torch.isinf(grasp_cos_out).any():
assert True
奈何复现一次nan出现需要等很久,我很难受。于是我去学习一位老哥的钩子函数方法来排查。形成的成果见我的上一篇文章。
扑朔迷离的情况:我发现在一个卷积输出全是nan,是全是nan,而输入是"正常的"(后来发现,应该是混有了nan), 我百思不得其解,卷积不就是对应相乘相加,也不可能出现nan啊。
我看了看输入输出dtype,是torch.float16! 精度范围只有1e-8 最大最小值只有6.5w。我怀疑会不会是卷积相乘相加溢出了?我又看了一眼nan出现时的输入,它几乎都是0,难道是计算结果太接近0,然后下溢了?
经过测试,发现计算结果只可能趋于0,而不会到+-6.5w,所以不是上溢出,下溢出是小数直接没有了,变为0,然而这是成了nan。
我寻思到底是不是torch.float16的问题。我想方设法将卷积计算那一块 变成torch.float32进行计算,发现nan依旧。啊啊啊!!!。我干脆把amp.auto_cast()注释掉,哈模型显存占用量直接翻倍,计算速度也慢了很多。但是!!nan一直都没有出现!!!
最后的曙光:经过关闭amp,我认为很可能就是torch.float16的问题!!!但是把卷积那里修改成32类型,nan依旧出现。
我认真想了想之前的钩子代码排查出来的nan出现的地方:包括BatchNorm、ReLU。BatchNorm要求mena和方差,要求和,很可能会上溢出。我发现确实中间的输入有10这个级别的数字,我估计就是BatchNorm没准了。
欧耶!我猜中了!BatchNorm换成了32类型进行计算,没有问题!!!!!!!!
总结:
问题根源:amp混合精度计算,使用fp16的精度计算batch_norm 导致nan
但是后面 batch_norm还是nan了,于是我只能删掉这个batch_norm了,然后就,没有问题了!!!