torch.index_select(input, dim, index, out=None) → Tensor
Returns a new tensor which indexes the input tensor along dimension dim using the entries in index which is a LongTensor.
The returned tensor has the same number of dimensions as the original tensor (input). The dimth dimension has the same size as the length of index; other dimensions have the same size as in the original tensor.
Note
The returned tensor does not use the same storage as the original tensor. If out has a different shape than expected, we silently change it to the correct shape, reallocating the underlying storage if necessary.
Parameters
-
input (Tensor) – the input tensor.
-
dim (int) – the dimension in which we index
-
index (LongTensor) – the 1-D tensor containing the indices to index
-
out (Tensor, optional) – the output tensor.
Example:
>>> x = torch.randn(3, 4)
>>> x
tensor([[ 0.1427, 0.0231, -0.5414, -1.0009],
[-0.4664, 0.2647, -0.1228, -1.1068],
[-1.1734, -0.6571, 0.7230, -0.6004]])
>>> indices = torch.tensor([0, 2])
>>> torch.index_select(x, 0, indices)
tensor([[ 0.1427, 0.0231, -0.5414, -1.0009],
[-1.1734, -0.6571, 0.7230, -0.6004]])
>>> torch.index_select(x, 1, indices)
tensor([[ 0.1427, -0.5414],
[-0.4664, -0.1228],
[-1.1734, 0.7230]])
torch.index_select函数用于根据指定维度上的索引返回一个新的张量,该张量与原始张量具有相同数量的维度,指定维度的大小与索引张量的长度相同,其他维度的大小保持不变。注意返回的张量不共享原始张量的存储。如果索引张量形状不正确,会自动调整并可能重新分配存储。
635

被折叠的 条评论
为什么被折叠?



