- Official documentation, for reference:
torch.where(condition, input, other, *, out=None) → Tensor
- Return a tensor of elements selected from either input or other, depending on condition.
- Parameters:
condition (BoolTensor) – When True (nonzero), yield input, otherwise yield other
input (Tensor or Scalar) – value (if input is a scalar) or values selected at indices where condition is True
other (Tensor or Scalar) – value (if other is a scalar) or values selected at indices where condition is False
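The parameter list says input and other may each be a tensor or a scalar. The minimal sketch below shows a tensor input with a scalar other, and then a condition whose shape is broadcast against the other arguments. It assumes a reasonably recent PyTorch release that accepts scalar arguments to torch.where. The tensor values and the variable names (x, relu_like, row_mask, picked) are made up for illustration.

```python
import torch

# Illustrative float tensor with both negative and positive entries
x = torch.tensor([[-1.5, 2.0, -0.5],
                  [3.0, -4.0, 1.0]])

# Tensor `input`, scalar `other`: keep positive entries, replace the rest with 0.0
relu_like = torch.where(x > 0, x, 0.0)
print(relu_like)

# A (2, 1) condition is broadcast across the columns of x:
# row 0 is taken from x, row 1 is taken from -x
row_mask = torch.tensor([[True], [False]])
picked = torch.where(row_mask, x, -x)
print(picked)
```

Because condition, input, and other are broadcast against each other, a per-row mask like row_mask does not need to be expanded to the full shape of x first.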
- Example:
>>> a = torch.arange(0,8).view(2,4)
>>> b = -torch.arange(0,8).view(2,4)
>>> torch.where(a>4, a, b)
tensor([[ 0, -1, -2, -3],
        [-4,  5,  6,  7]])
>>> a>4
tensor([[False, F