利用PyTorch和机器学习掌握计算机视觉:WGAN与WGAN - GP实战
1. 梯度计算与惩罚示例
首先,我们来看一个梯度计算和梯度惩罚的示例代码。通过这段代码,我们可以了解如何使用 torch.autograd.grad 来计算梯度,并计算梯度惩罚。
import torch; import torch.nn as nn; from torch.autograd import grad
x = torch.FloatTensor([0, 0.5, 1.5, 2])
x.requires_grad=True
print('x =', x)
f = (x*x).sum()
# f = x[0]**2 + x[1]**2 + x[2]**2 + x[3]**2
print('f =', f) # f= tensor(6.5000)
df2dx = grad(outputs=f,
inputs=x,
grad_outputs=torch.ones_like(f),
create_graph=True,
retain_graph=True)[0]
print('df2dx = ', df2dx)
#df2dx= tensor([0., 1., 3., 4.])
grad_norm = df2dx.norm(p=2, dim=0)
print('grad_norm=', grad_norm)
#grad_norm= tensor(5.1)
gradient_penalty = 10*torch.mean
超级会员免费看
订阅专栏 解锁全文
7743

被折叠的 条评论
为什么被折叠?



