torch.gather结合torch.squeeze和torch.unsqueeze实现提取每行指定元素(可扩展至高维)
import torch
conf_t = torch.Tensor([[1,2,3],[2,3,4]])
print(conf_t)
idx = torch.LongTensor([2, 1])
print(idx)
idx2 = idx.unsqueeze(-1)# 增加1个维度,-1表示在最高维度上增加
print(idx2)
targets_weighted = torch.gather(conf_t, -1, idx2)# 以idx2为索引,聚集conf_t中对应的元素
print(targets_weighted)
targets_weighted = targets_weighted.squeeze(-1)# 降低1个维度,-1表示消去最高维
print(targets_weighted)
输出:
tensor([[1., 2., 3.],
[2., 3., 4.]])
tensor([2, 1])
tensor([[2],
[1]])
tensor([[3.],
[3.]])
tensor([3., 3.])