这里只介绍pytorch的高级索引,是一些奇怪的切片索引
基本版
a[[0, 2], [1, 2]] 等价 a[0, 1] 和 a[2, 2],相当于索引张量的第一行的第二列和第三行的第三列元素;
a[[1, 0, 2], [0]] 等价 a[1, 0] 和 a[0, 0] 和 a[2, 0],相当于索引张量的第二行第一列的元素、张量第一行和第一列的元素以及张量第三行和第一列的元素
import torch
a = torch.arange(9).view([3, 3])
print(a)
b = a[[0, 2], [1, 2]]
print(b)
c = a[[1, 0, 2], [0]]
print(c)
# ---------output----------
# tensor([[0, 1, 2],
# [3, 4, 5],
# [6, 7, 8]])
# tensor([1, 8])
# tensor([3, 0, 6])
# 这里参考了:https://zhuanlan.zhihu.com/p/509591863
高级索引的原则:索引中有: 就代表着改维度全部取,在哪个维度放置索引,就代表想取哪个维度的内容
扩展A:
import torch
a = torch.arange(30).view([

本文介绍了如何在PyTorch中使用高级索引来操作张量,包括基础索引、扩展A、扩展B和扩展C示例,以及torch.gather函数的应用。
最低0.47元/天 解锁文章
307

被折叠的 条评论
为什么被折叠?



