创建一个数据y_hat
,其中包含2个样本在3个类别的预测概率,使用y
作为y_hat
中概率的索引
y = torch.tensor([0, 2])
y_hat = torch.tensor([[0.1, 0.3, 0.6], [0.3, 0.2, 0.5]])
y_hat[[0, 1], y]
output:
tensor([0.1000, 0.5000])
创建一个数据y_hat
,其中包含2个样本在3个类别的预测概率,使用y
作为y_hat
中概率的索引
y = torch.tensor([0, 2])
y_hat = torch.tensor([[0.1, 0.3, 0.6], [0.3, 0.2, 0.5]])
y_hat[[0, 1], y]
output:
tensor([0.1000, 0.5000])