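Implementing tf.gather_nd in torch

```python
import torch

def th_gather_nd(x, coords):
    # Gather x at full coordinates: coords has shape (N, x.dim()), result has shape (N,)
    x = x.contiguous()
    strides = torch.tensor(x.stride(), dtype=torch.long, device=x.device)
    # Flat offset of each row: sum_j coords[:, j] * x.stride(j)
    inds = (coords.long() * strides).sum(dim=1)
    return torch.index_select(x.view(-1), 0, inds)
```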
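The stride trick in th_gather_nd covers only the case where every coordinate row has one entry per dimension of x. The flat offset of a coordinate (i0, ..., ik-1) is the sum of i_j * stride_j, so for a contiguous tensor of shape (2, 3, 4) with strides (12, 4, 1), the coordinate (1, 2, 3) maps to 12 + 8 + 3 = 23. tf.gather_nd also accepts shorter index rows (last index dimension M smaller than the rank), which return whole slices, with output shape indices.shape[:-1] + params.shape[M:]. Below is a minimal sketch of that more general behaviour, assuming batch_dims=0 and using PyTorch advanced indexing instead of stride arithmetic; the name gather_nd is hypothetical.

```python
import torch

def gather_nd(params, indices):
    # Hypothetical helper sketching tf.gather_nd with batch_dims=0.
    # indices has shape (..., M) with M <= params.dim(); splitting the last
    # axis gives M index tensors that advanced-index the first M dimensions,
    # so the result has shape indices.shape[:-1] + params.shape[M:].
    indices = indices.long()
    return params[tuple(indices.unbind(dim=-1))]

# x has shape (2, 3, 4); its contiguous strides are (12, 4, 1).
x = torch.arange(24).reshape(2, 3, 4)

# M == x.dim(): each row selects one scalar (values 23 and 4).
full = torch.tensor([[1, 2, 3], [0, 1, 0]])
print(gather_nd(x, full))

# M < x.dim(): each row selects a length-4 slice, x[1, 2] and x[0, 0].
partial = torch.tensor([[1, 2], [0, 0]])
print(gather_nd(x, partial))
```

For full-length index rows this gives the same values as the stride-based version; for shorter rows the indexing approach handles the slicing that th_gather_nd does not.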