在编写好自己的 autograd function 后,可以利用
gradcheck
中提供的gradcheck
和gradgradcheck
接口,对数值算得的梯度和求导算得的梯度进行比较,以检查backward
是否编写正确。 即用它可以check自己写的反向传播函数是否正确这个函数可以自己计算通过数值法求得的梯度,然后和我们写的backward的结果比较
在下面的例子中,我们自己实现了
Sigmoid
函数,并利用gradcheck
来检查backward
的编写是否正确。import torch from torch.autograd import Function import torch class Sigmoid(Function): @staticmethod def forward(ctx, x): output = 1 / (1 + torch.exp(-x)) ctx.save_for_backward(output) return output @staticmethod def backward(ctx, grad_output): output, = ctx.saved_tensors grad_x = output * (1 - output) * grad_output return grad_x test_input = torch.randn(4, requires_grad=True) # tensor([-0.4646, -0.4403, 1.2525, -0.5953], requires_grad=True) print(torch.autograd.gradcheck(Sigmoid.apply, (test_input,), eps=1e-3)) # pass print(torch.autograd.gradcheck(torch.sigmoid, (test_input,), eps=1e-3)) # pass
返回True就代表是正确的
参考