’torch.round后梯度为0,无法进行梯度回传‘的解决方法

文章介绍了在PyTorch中,由于torch.round函数导致梯度无法反向传播的问题,并提出了解决方案——通过自定义ste_round函数,利用torch.round(x)-x.detach()+x实现可微分的round操作,确保梯度的正确传递。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

torch.round后梯度为0,无法进行梯度回传的解决方法

问题描述:round函数在定义域中的导数,处处为0或者无穷,梯度无法反向传播。本文将使用autograd.function类自定义可微分的round函数,使得round前后的tensor,具有相同的梯度。

解决方法:

def ste_round(x):
    return torch.round(x) - x.detach() + x

解析:
torch.round(x)导数处处为0,x.detach()在计算图中无梯度,因此其ste_round的倒数就是x的导数。
即:
torch.round(x)导数处处为0,x.detach()在计算图中,x的导数为1

参考:

  1. Pytorch 可微分round函数
  2. pytorch中如何对tensor进行取整操作且梯度不变为0?
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值