fluid.layers.gather(input, index, overwrite=True)
- 作用:根据索引 index 获取输入(input)的最外层维度的条目,并将它们拼接在一起。
- 链接:pp飞桨API说明
- 对比:通常来说此api用法和pytorch没什么特殊,但因为API缺少了dim参数,使得当多维tensor索引时报错
示例
输入维度
- X.shape = [b, c, 81204]
- Index.shape = [b, c, 720000]
输出维度 - X_offset = [b, c, 720000]
# pytorch code
x_offset = x.gather(dim=-1, index=index)

博客对比了PaddlePaddle和PyTorch中gather API的用法。Paddle的gather API缺少dim参数,在多维tensor索引时会报错。文中给出输入输出维度示例,总结出在Paddle中需采用for、reshape和concat相结合的方法来处理。
最低0.47元/天 解锁文章
1830

被折叠的 条评论
为什么被折叠?



