import torch
#****************************** Detect NaN/Inf in module inputs, outputs and model parameters. Plug-and-play: drop it into the existing code. (检测输入输出模型参数是否有NaN/Inf,可即插即拔,放在原来的代码中) ******************************
def _tensors_in(value):
    """Return the tensors directly held by a forward-hook ``input``/``output`` value.

    A bare tensor yields itself; a tuple or list yields its top-level tensor
    elements (nested containers are not searched). Anything else -- ``None``,
    dicts, custom objects -- yields nothing.
    """
    if isinstance(value, torch.Tensor):
        return (value,)
    if isinstance(value, (tuple, list)):
        return tuple(item for item in value if isinstance(item, torch.Tensor))
    return ()


def _report_nan(value, shape_label, position, module):
    """Print a diagnostic for every tensor in *value* that contains a NaN.

    Args:
        value: the hook's ``input`` or ``output`` argument.
        shape_label: prefix of the shape line ("input" or "Output").
        position: "before" for inputs, "after" for outputs.
        module: the module whose hook fired; only its class name is printed.
    """
    for tensor in _tensors_in(value):
        try:
            has_nan = bool(torch.isnan(tensor).any())
        except (RuntimeError, NotImplementedError):
            # Best-effort checker: some tensor kinds may not support isnan or
            # bool() (e.g. quantized/meta tensors) -- skip them, don't crash.
            continue
        if has_nan:
            print(f"{shape_label} shape: {tensor.shape}")
            print(f"NaN detected {position} layer: {module.__class__.__name__}")


def hook_function(module, input, output):
    """Forward hook reporting NaNs in inputs/outputs and NaN/Inf in parameters.

    Intended for ``module.register_forward_hook(hook_function)``. Findings are
    printed to stdout; the hook never modifies *output* and returns ``None``.

    Args:
        module: the module whose forward just ran.
        input: positional arguments passed to ``module.forward`` (a tuple when
            invoked by PyTorch; a bare tensor is also accepted).
        output: value returned by ``module.forward``. A tensor, or the
            top-level tensors inside a tuple/list, are checked; any other
            value (``None``, dict, ...) is ignored.
    """
    _report_nan(input, "input", "before", module)
    _report_nan(output, "Output", "after", module)
    # NOTE: named_parameters() recurses into submodules, so when this hook is
    # attached to every module a bad parameter is reported by its owning
    # module and again by each ancestor module.
    for name, param in module.named_parameters():
        # isfinite is False exactly for NaN and +/-Inf entries.
        if not torch.isfinite(param).all():
            print(f"NaN detected in layer param: {module.__class__.__name__}")
            print(f"Parameter {name} has NaN or Inf values")
class ModelWrapper:
    """Attach a forward hook to every module of a model and detach it later.

    Typical use: ``ModelWrapper(model).register_hooks()`` right after the
    model is built, then ``remove_hooks()`` once debugging is done.

    Attributes:
        model: the wrapped ``torch.nn.Module``.
        hooks: handles returned by ``register_forward_hook``; emptied by
            ``remove_hooks``.
    """

    def __init__(self, model):
        self.model = model
        self.hooks = []

    def register_hooks(self, hook=None):
        """Register *hook* as a forward hook on every module of ``self.model``.

        Args:
            hook: callable ``(module, input, output)``. Defaults to
                ``hook_function`` from this file (looked up at call time).
        """
        if hook is None:
            hook = hook_function
        # modules() yields the root model and all of its submodules; every
        # item is already an nn.Module, so no type filtering is needed.
        for module in self.model.modules():
            self.hooks.append(module.register_forward_hook(hook))

    def remove_hooks(self):
        """Detach every hook previously registered through this wrapper."""
        for handle in self.hooks:
            handle.remove()
        # Drop the spent handles so a later register/remove cycle starts clean.
        self.hooks.clear()
# Wrap the already-built model and hook every one of its modules.
# NOTE(review): `model` is not defined in this file; these lines only work
# once pasted into a scope (e.g. the training script) where it exists.
wrapper = ModelWrapper(model=model)
wrapper.register_hooks()
#**************************************************************************************
# Usage: place the code above in the training script right after `model` has been created. (上面的代码放到训练时,创建了model之后即可。)