torch.where() 用于将两个broadcastable的tensor组合成新的tensor,类似于c++中的三元操作符“?:”
区别于python numpy中的where()直接可以找到特定条件元素的index

想要实现numpy中where()的功能,可以借助nonzero()

对应numpy中的where()操作效果:

本文对比了PyTorch的where()函数与NumPy的where()功能,解释了两者之间的主要区别。PyTorch的where()用于结合两个可广播的tensor,而NumPy的where()则直接返回满足特定条件的元素索引。为了在PyTorch中实现类似NumPy where()的效果,可以使用nonzero()函数。
torch.where() 用于将两个broadcastable的tensor组合成新的tensor,类似于c++中的三元操作符“?:”
区别于python numpy中的where()直接可以找到特定条件元素的index

想要实现numpy中where()的功能,可以借助nonzero()

对应numpy中的where()操作效果:

您可能感兴趣的与本文相关的镜像
PyTorch 2.5
PyTorch 是一个开源的 Python 机器学习库,基于 Torch 库,底层由 C++ 实现,应用于人工智能领域,如计算机视觉和自然语言处理
1052
1740

被折叠的 条评论
为什么被折叠?