UserWarning message: /pytorch/aten/src/ATen/native/cuda/LegacyDefinitions.cpp:14: UserWarning: masked_fill_ received a mask with dtype torch.uint8, this behavior is now deprecated,please use a mask with dtype torch.bool instead.
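Solution:
Change
tensor.masked_fill_(data, value)
to
tensor.masked_fill_(data.bool(), value)
Here data is the torch.uint8 mask. Converting it to torch.bool with .bool() before passing it to masked_fill_ is all that is needed to get rid of the warning.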
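A minimal runnable sketch of the fix, assuming a small float tensor `x` and a fill value of `-1.0` (both hypothetical, chosen only for illustration); `data` plays the role of the uint8 mask from the note above. Depending on the PyTorch version, passing the uint8 mask directly either emits the deprecation warning or is rejected outright, so the sketch only shows the converted form.

```python
import torch

# Hypothetical float tensor to be modified in place: [[0, 1, 2], [3, 4, 5]]
x = torch.arange(6, dtype=torch.float32).reshape(2, 3)

# Legacy uint8 mask, as produced by older code (1 = fill this position)
data = torch.tensor([[1, 0, 1], [0, 1, 0]], dtype=torch.uint8)

# Convert the mask to torch.bool before calling masked_fill_;
# the fill value -1.0 is an arbitrary example value
x.masked_fill_(data.bool(), -1.0)

print(x)
```

Only the mask's dtype changes; which positions get filled is the same, because `.bool()` maps every nonzero entry to True and every zero entry to False.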