torch.where我所知道的用法有两种,
第一种为torch.where(condition,a,b),condition为条件判定,当满足条件时,对应位置为a的值否则为b的值,例如
import torch
a = torch.tensor([[0.5,2,1],
[1,1,1],
[0,3,2]])
b = torch.tensor([[0.6,1,1],
[2,3,2],
[4,5,1]])
c=torch.where(a>b,a,b)
#c的输出结果为
tensor([[0.6000, 2.0000, 1.0000],
[2.0000, 3.0000, 2.0000],
[4.0000, 5.0000, 2.0000]])
第二种使用方式是返回tensor中为true的索引,例如
import torch
a=torch.tensor([[0.5,2,1],
[1,1,1],
[0,3,2]])
b=torch.tensor([0.5,1,3])
#torch.eq(a,b[:,None])为判断值是否相等
tensor([[ True, False, False],
[ True, True, True],
[False, True, False]])
torch.where(torch.eq(a,b[:,None]))
#输出结果为
(tensor([0, 1, 1, 1, 2]), tensor([0, 0, 1, 2, 1]))
输出的结果分别对应true位置的行索引和列索引