A plug-and-play check for NaN problems during PyTorch training

The snippet below registers a forward hook on every module of a model, so NaN or Inf values in layer inputs, outputs, or parameters are reported the moment they appear, without touching the model code itself.

import torch


#****************************** Check module inputs, outputs, and parameters for NaN/Inf.
# Plug-and-play: drop this into your existing code. ******************************
def _report_nan_inf(value, position, module):
    """Print a warning for every tensor in `value` that contains NaN or Inf."""
    if isinstance(value, torch.Tensor):
        value = (value,)
    if not isinstance(value, (tuple, list)):
        return  # skip non-tensor payloads such as dicts or None
    for t in value:
        if isinstance(t, torch.Tensor) and (torch.isnan(t).any() or torch.isinf(t).any()):
            print(f"tensor shape: {t.shape}")
            # Optional stats, useful when chasing the source of the blow-up:
            # print("stats - mean: {}, std: {}, min: {}, max: {}".format(
            #     t.mean().item(), t.std().item(), t.min().item(), t.max().item()))
            print(f"NaN/Inf detected {position} layer: {module.__class__.__name__}")


def hook_function(module, input, output):
    # For forward hooks, `input` is always a tuple; `output` may be a tensor,
    # a tuple of tensors, or something else entirely, depending on the module.
    _report_nan_inf(input, "before", module)
    _report_nan_inf(output, "after", module)

    # Check the module's own parameters as well. recurse=False so each
    # parameter is reported once, by the module that owns it, rather than by
    # every ancestor module (the hooks are registered on all modules below).
    for name, param in module.named_parameters(recurse=False):
        if torch.isnan(param).any() or torch.isinf(param).any():
            print(f"NaN/Inf detected in layer param: {module.__class__.__name__}")
            print(f"Parameter {name} has NaN or Inf values")

        
class ModelWrapper:
    def __init__(self, model):
        self.model = model
        self.hooks = []

    def register_hooks(self):
        # Walk every submodule of the model (named_modules() yields only
        # nn.Module instances) and register the forward hook on each.
        for name, module in self.model.named_modules():
            self.hooks.append(module.register_forward_hook(hook_function))

    def remove_hooks(self):
        # Remove all previously registered hooks.
        for hook in self.hooks:
            hook.remove()

# `model` is your own nn.Module, created earlier in the training script.
wrapper = ModelWrapper(model)

# Register the hooks.
wrapper.register_hooks()

#**************************************************************************************

Drop the code above into your training script anywhere after the model has been created.
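For reference, here is a minimal end-to-end sketch of the workflow; the two-layer nn.Sequential model and the deliberately NaN-poisoned input are made up purely for illustration:

import torch
import torch.nn as nn

# Toy model, a stand-in for your real network.
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))

wrapper = ModelWrapper(model)
wrapper.register_hooks()

# Poison the input with a NaN so the hooks have something to report.
x = torch.randn(1, 4)
x[0, 0] = float("nan")
model(x)

# Remove the hooks once you are done debugging: the NaN/Inf scan runs on
# every forward pass and slows training down.
wrapper.remove_hooks()

PyTorch's built-in torch.autograd.set_detect_anomaly(True) is a complementary tool: it additionally flags NaNs produced during the backward pass, at a similar runtime cost.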
