"""Step-by-step focal-loss computation on masked per-position class scores.

This is a runnable version of an interactive session that produced ``nan``
losses, with the cause fixed. The session took ``log()`` of scores gathered
straight from ``pred``. Those scores are unnormalized (several were negative),
so ``log`` returned ``nan`` and the focal weight ``(1 - p) ** gamma`` was not
applied to a probability. Focal loss needs the probability of the target
class, so the scores are normalized with ``log_softmax`` before gathering.
"""
import torch
import torch.nn.functional as F

# Fixed seed so the printed tensors are reproducible between runs.
torch.manual_seed(0)

num_classes = 3  # number of score channels along pred's first dim
gamma = 2.0  # focusing parameter; the original session hard-coded 2

# Per-position class ids; values >= num_classes (6, 4, 3 here) are excluded
# by the mask below.
label = torch.tensor([[0, 1, 6, 4], [3, 3, 0, 2]])
# Raw scores with the class axis first: (num_classes, *label.shape).
pred = torch.randn(num_classes, *label.shape)
# Move the class axis last so pred_1[i, j] is the score vector for label[i, j].
pred_1 = pred.permute(1, 2, 0)

# The comparison yields a bool mask on current PyTorch; the session's older
# version produced uint8, which is deprecated for indexing.
mask = label < num_classes
target = label[mask]  # selected class ids, here tensor([0, 1, 0, 2])
p = pred_1[mask]  # (num_selected, num_classes) raw scores

ids = target.view(-1, 1)
# One-hot of each target class, sized from p rather than a hard-coded (4, 3).
class_mask = torch.zeros_like(p).scatter_(1, ids, 1.0)

# Log-probability of the target class. log_softmax is the numerically stable
# form of softmax(...).log(); calling .log() on raw scores produced nan.
log_p = (F.log_softmax(p, dim=1) * class_mask).sum(1).view(-1, 1)
probs = log_p.exp()  # target-class probability, in (0, 1]

# Per-class weights (all ones, as in the session), looked up per target.
a = torch.ones(num_classes, 1)
alpha = a[target]

# Focal loss per selected position: -alpha * (1 - p_t) ** gamma * log(p_t).
batch_loss = -alpha * torch.pow(1 - probs, gamma) * log_p
loss = batch_loss.mean()

if __name__ == "__main__":
    print("probs:", probs.view(-1))
    print("batch_loss:", batch_loss.view(-1))
    print("loss:", loss.item())