pytorch Dropout错误 noise.bernoulli_(1 - ctx.p).div_(1 - ctx.p) divide by zero

本文探讨了PyTorch框架中long和float类型数据运算时可能出现的错误,特别是在使用dropout功能时遇到的问题。文章详细介绍了错误产生的原因,并提供了一种将long类型转换为float类型的解决方案。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

pytorch 在计算long和float运算时会出现错误:

>>> import torch
>>> a = torch.tensor([1,2,3], dtype=torch.long)
>>> a + 0.5

 1
 2
 3
[torch.LongTensor of size (3,)]

这对应到在long类型数据使用dropout上时有时会出现如下问题:

/usr/local/lib/python3.6/dist-packages/torch/nn/_functions/dropout.py in forward(cls, ctx, input, p, train, inplace)
     38             ctx.noise.fill_(0)
     39         else:
---> 40             ctx.noise.bernoulli_(1 - ctx.p).div_(1 - ctx.p)
     41         ctx.noise = ctx.noise.expand_as(input)
     42         output.mul_(ctx.noise)

RuntimeError: invalid argument 3: divide by zero at /pytorch/aten/src/THC/generic/THCTensorMathPairwise.cu:88

转成float类型可以暂时解决:

>>> a
tensor([ 100,    3,  100])
>>> a.float()
tensor([ 100.,    3.,  100.])
>>> d
Dropout(p=0.5)
>>> d(a.float())
tensor([ 200.,    0.,    0.]) # works fine
>>> d(a)
Floating point exception # throws error

暂时解决办法:

dropout_p = 0.1
dropout = torch.nn.Droupout(p=droupout_p)
if self.training:
    x_ = dropout(inputs.float())
    print(x_)
    inputs = torch.round(x_.mul(1-dropout_p)).long()  #round用来解决浮点数的误差问题 
    print(inputs)

 

参考自

 

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值