torch.where() 用于将两个broadcastable的tensor组合成新的tensor,类似于c++中的三元操作符“?:”
区别于python numpy中的where()直接可以找到特定条件元素的index

想要实现numpy中where()的功能,可以借助nonzero()

对应numpy中的where()操作效果:

本文对比了PyTorch的where()函数与NumPy的where()功能,解释了两者之间的主要区别。PyTorch的where()用于结合两个可广播的tensor,而NumPy的where()则直接返回满足特定条件的元素索引。为了在PyTorch中实现类似NumPy where()的效果,可以使用nonzero()函数。
torch.where() 用于将两个broadcastable的tensor组合成新的tensor,类似于c++中的三元操作符“?:”
区别于python numpy中的where()直接可以找到特定条件元素的index

想要实现numpy中where()的功能,可以借助nonzero()

对应numpy中的where()操作效果:

1050

被折叠的 条评论
为什么被折叠?