最近遇到一个头疼的事情,每次在模型部分加一些新的模块,就出现grad=nan的情况。仔细检查后发现是初始化参数的问题,torch自带的初始化模块(torch.nn.init)不知道为啥不起作用,而且数值非常大,有可能是没有适配bf16?
import torch
from timm.models.layers import trunc_normal_
class ...:
def _init_weights(self, m):
if isinstance(m, (nn.Conv2d, nn.Linear)):
trunc_normal_(m.weight, std=.01, a=-1.0, b=1.0) # okay
# torch.nn.init.trunc_normal_ does not work here
try:
# nn.init.constant_(m.bias, 0) # does not work
m.bias.data.fill_(0) # okay
except:
pass
写在2024年11月