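In the PyTorch version used here, arithmetic between a long tensor and a float scalar gives a wrong result. The float is converted to an integer (0.5 becomes 0), so the result is still a LongTensor holding the original values: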
>>> import torch
>>> a = torch.tensor([1,2,3], dtype=torch.long)
>>> a + 0.5
1
2
3
[torch.LongTensor of size (3,)]
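A minimal sketch of how to avoid the truncation: cast the long tensor to float explicitly before mixing it with a float scalar. Newer PyTorch versions promote long + float scalar to a float tensor on their own, but the explicit cast gives the same result on any version.

```python
import torch

a = torch.tensor([1, 2, 3], dtype=torch.long)

# Cast explicitly so the fractional part is not lost, regardless of
# how a given PyTorch version handles long + float.
b = a.float() + 0.5
print(b)        # tensor([1.5000, 2.5000, 3.5000])
print(b.dtype)  # torch.float32
```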
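The same behavior is behind an error that dropout raises when it is applied to long-type data. Because ctx.noise has the same type as the input, 1 - ctx.p is converted to an integer as well (0.5 becomes 0), so div_ ends up dividing by zero: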
/usr/local/lib/python3.6/dist-packages/torch/nn/_functions/dropout.py in forward(cls, ctx, input, p, train, inplace)
38 ctx.noise.fill_(0)
39 else:
---> 40 ctx.noise.bernoulli_(1 - ctx.p).div_(1 - ctx.p)
41 ctx.noise = ctx.noise.expand_as(input)
42 output.mul_(ctx.noise)
RuntimeError: invalid argument 3: divide by zero at /pytorch/aten/src/THC/generic/THCTensorMathPairwise.cu:88
Converting the input to float works around the problem for now:
>>> a
tensor([ 100, 3, 100])
>>> a.float()
tensor([ 100., 3., 100.])
>>> d
Dropout(p=0.5)
>>> d(a.float())
tensor([ 200., 0., 0.]) # works fine
>>> d(a)
Floating point exception # throws error
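To see why the float path works, here is a small sketch (assuming a current PyTorch) of what nn.Dropout does to the float copy. In training mode each element is either zeroed or kept and scaled by 1 / (1 - p), which is why 100 became 200 above with p = 0.5; in eval mode the module returns its input unchanged. The fixed seed is only there to make the run repeatable.

```python
import torch

torch.manual_seed(0)  # fixed seed only so the run is repeatable

a = torch.tensor([100, 3, 100])  # long tensor, as in the session above
d = torch.nn.Dropout(p=0.5)

out = d(a.float())  # training mode is the default
# Every element is either dropped (0) or kept and scaled by 1 / (1 - p) = 2.
print(out)

d.eval()  # in eval mode dropout is the identity
print(d(a.float()))  # returns the float input unchanged
```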
Temporary workaround (apply dropout to a float copy, undo the scaling, then convert back to long):
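import torch

class LongInputDropout(torch.nn.Module):  # hypothetical name for the enclosing module
    def __init__(self, dropout_p=0.1):
        super().__init__()
        self.dropout_p = dropout_p
        self.dropout = torch.nn.Dropout(p=dropout_p)

    def forward(self, inputs):
        if self.training:
            x_ = self.dropout(inputs.float())  # dropout on a float copy; kept values are scaled by 1/(1-p)
            # undo the 1/(1-p) scaling; round() removes floating-point error before casting back to long
            inputs = torch.round(x_.mul(1 - self.dropout_p)).long()
        return inputs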
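As an alternative to the float round trip, one could sample the Bernoulli keep-mask directly and multiply it into the integer tensor. This is a sketch, not part of the original workaround, and `long_dropout` is a hypothetical helper name. It produces the same kind of result as the workaround (kept values unchanged, dropped values set to 0) but never converts the data itself to float. That matters for large values, since float32 represents integers exactly only up to 2^24.

```python
import torch

def long_dropout(inputs: torch.Tensor, p: float, training: bool = True) -> torch.Tensor:
    """Zero each element of an integer tensor with probability p (hypothetical helper).

    Unlike nn.Dropout, kept values are NOT rescaled by 1/(1 - p), matching the
    workaround above after it multiplies by (1 - p) and rounds.
    """
    if not training or p == 0:
        return inputs
    # Draw the keep-mask in floating point on the input's device, then cast it
    # to the input's integer dtype so the multiplication stays integer-only.
    probs = torch.full(inputs.shape, 1 - p, device=inputs.device)
    keep = torch.bernoulli(probs).to(inputs.dtype)
    return inputs * keep

x = torch.tensor([100, 3, 100])
y = long_dropout(x, p=0.1)
print(y, y.dtype)  # dtype stays torch.int64
```

Because kept values are not rescaled, the expected value of each output element is (1 - p) times the input, which is also what the original workaround produces.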