round函数在定义域中的导数,处处为0或者无穷,梯度无法反向传播。本文将使用autograd.function类自定义可微分的round函数,使得round前后的tensor,具有相同的梯度。
from torch.autograd import Function
class BypassRound(Function):
@staticmethod
def forward(ctx, inputs):
return torch.round(inputs)
@staticmethod
def backward(ctx, grad_output):
# 这里的grad_output是round之后的tensor的梯度,直接将它作为round之前tensor的梯度
return grad_output
# Function.apply的别名
bypass_round = BypassRound.apply
# demo
z3_rounded = bypass_round(z3)
具体原理和细节参考以下博客:
定义torch.autograd.Function的子类,自己定义某些操作,且定义反向求导函数_tsq292978891的博客-优快云博客_saved_tensors
2022.4.7更新:更简单的方法如下
def ste_round(x):
return torch.round(x) - x.detach() + x
torch.round(x)导数处处为0,x.detach()在计算图中,x的导数为1
因此:ste_round(x)的梯度 == x的梯度