torch.round后梯度为0,无法进行梯度回传的解决方法
问题描述:round函数在定义域中的导数,处处为0或者无穷,梯度无法反向传播。本文将使用autograd.function类自定义可微分的round函数,使得round前后的tensor,具有相同的梯度。
解决方法:
def ste_round(x):
return torch.round(x) - x.detach() + x
解析:
torch.round(x)导数处处为0,x.detach()在计算图中无梯度,因此其ste_round的倒数就是x的导数。
即:
torch.round(x)导数处处为0,x.detach()在计算图中,x的导数为1
参考: