需要注意一点,代码需要写在 backward() 前面,因为钩子函数是在 backward 时执行的。
var.register_hook(lambda grad: torch.where(torch.isnan(grad), torch.ones_like(grad), grad))
total_loss.backward()
本文解释了在PyTorch中,为了处理计算过程中可能出现的NaN梯度,如何在`backward()`函数之前使用`torch.where()`和自定义hook函数`lambdagrad`。
需要注意一点,代码需要写在 backward() 前面,因为钩子函数是在 backward 时执行的。
var.register_hook(lambda grad: torch.where(torch.isnan(grad), torch.ones_like(grad), grad))
total_loss.backward()

被折叠的 条评论
为什么被折叠?