userwarning描述:/pytorch/aten/src/ATen/native/cuda/LegacyDefinitions.cpp:14: UserWarning: masked_fill_ received a mask with dtype torch.uint8, this behavior is now deprecated,please use a mask with dtype torch.bool instead.
解决方法:
将
masked_fill_(data)
改成
masked_fill_(data.bool())
即可
本文解决了PyTorch中masked_fill_函数使用uint8类型mask的警告问题,建议改为bool类型,避免未来版本的兼容性问题。
1258

被折叠的 条评论
为什么被折叠?



