使用pytorch实现tf.batch_gather:
def batch_gather(data:torch.Tensor, index:torch.Tensor):
length = index.shape[0]
t_index = index.data.numpy()
t_data = data.data.numpy()
result = []
for i in range(length):
result.append(t_data[i, t_index[i], :])
return torch.from_numpy(np.array(result))
本文介绍了一种使用PyTorch实现TensorFlow中tf.batch_gather功能的方法。通过定义一个名为batch_gather的函数,该函数接受两个参数:data和index。此函数将数据转换为NumPy数组,然后遍历索引,从原始数据中收集对应元素,并将结果重新转换为PyTorch张量返回。
2772

被折叠的 条评论
为什么被折叠?



