https://github.com/pytorch/pytorch/issues/81027
使用Categorical分布的时候遇到报错:ValueError: Expected parameter probs (Tensor of shape (64, 32, 8, 16)) of distribution OneHotCategoricalStraightThrough() to satisfy the constraint Simplex(), but found invalid values:
解决办法:
1. 转换类型
因为我用的是bfloat16,sum起来不等于1,转换类型为float就好了
2. 设置参数
设置参数 validate_args=False 就可以忽略对sum=1的检查
704

被折叠的 条评论
为什么被折叠?



