For a 2D logits tensor of shape (num_samples, num_classes), dim must be 1 so that softmax normalizes across the class dimension and each row sums to one.
import torch
import torch.nn.functional as F

# 4 samples, 2 output classes
logits = torch.randn(4, 2)
print(F.softmax(logits, dim=1))
tensor([[0.7018, 0.2982],
        [0.9550, 0.0450],
        [0.4557, 0.5443],
        [0.8057, 0.1943]])
for row in F.softmax(logits, dim=1):
    print(row.sum())
tensor(1.)
tensor(1.)
tensor(1.)
tensor(1.)
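To make the role of dim concrete, here is a small sketch (variable names are illustrative) contrasting dim=1 with dim=0. With dim=0, softmax would normalize down each column, i.e. across samples, which is not what we want for per-sample class probabilities; note also that dim=-1 is equivalent to dim=1 for a 2D tensor, since -1 refers to the last dimension.

import torch
import torch.nn.functional as F

logits = torch.randn(4, 2)

# dim=1 normalizes across classes: each of the 4 rows sums to 1
probs_rows = F.softmax(logits, dim=1)
print(probs_rows.sum(dim=1))  # tensor([1., 1., 1., 1.])

# dim=0 normalizes across samples: each of the 2 columns sums to 1 instead
probs_cols = F.softmax(logits, dim=0)
print(probs_cols.sum(dim=0))  # tensor([1., 1.])

# dim=-1 targets the last dimension, so it matches dim=1 here
print(torch.allclose(probs_rows, F.softmax(logits, dim=-1)))  # True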