torch.index_select()
函数详解
torch.index_select()
是 PyTorch 中用于在 某一维度上按索引选取元素 的函数,适用于从张量中抽取特定的行、列或更高维度切片,是比切片更灵活的选择操作。
1. 函数原型
torch.index_select(input, dim, index) → Tensor
参数 | 说明 |
---|---|
input | 输入张量 |
dim | 要在哪一维度上选取索引(例如:0 表示行,1 表示列) |
index | 一个 1D 的 LongTensor,表示要选取的索引位置 |
返回值 | 新张量,包含从 input 中选取的元素 |
2. 基本用法示例
2.1 按行选取(dim=0
)
import torch
x = torch.tensor([[10, 11],
[12, 13],
[14, 15]])
# 选第 2 行、第 0 行
idx = torch.tensor([2, 0])
result = torch.index_select(x, dim=0, index=idx)
print(result)
输出:
tensor([[14, 15],
[10, 11]])
2.2 按列选取(dim=1
)
x = torch.tensor([[10, 11, 12],
[13, 14, 15]])
# 选第 2 列和第 0 列
idx = torch.tensor([2, 0])
result = torch.index_select(x, dim=1, index=idx)
print(result)
输出:
tensor([[12, 10],
[15, 13]])
3. 应用场景
3.1 选择特定样本或特征
data = torch.randn(100, 64) # 100 个样本,每个 64 维
selected = torch.index_select(data, dim=0, index=torch.tensor([0, 5, 10]))
3.2 从字典或嵌入中选择词向量
embedding = torch.randn(1000, 300) # 假设词表大小为 1000,每个词 300 维
word_ids = torch.tensor([3, 7, 9])
word_vectors = torch.index_select(embedding, 0, word_ids)
4. 与其他函数的对比
函数 | 功能 | 特点 |
---|---|---|
index_select | 按维度选择给定索引的切片 | 输入索引必须是 1D |
gather | 按每个位置指定索引选元素 | 索引 shape 与输出 shape 相同 |
select | 选定某一维的一个具体位置 | 只能选一个 index |
slice | 连续索引(如 0:3) | 不支持不连续 |
5. 注意事项
index
必须是 1D 的 LongTensor。- 返回的是一个 新的张量,原始张量不会改变。
dim
的值不能超过input.ndim - 1
。
6. 总结
特性 | 说明 |
---|---|
用途 | 在指定维度上按索引选取元素(如选特定行/列) |
索引类型 | 1D LongTensor |
常用场景 | 采样、提取子集、词向量查表、数据过滤 |
区别于 gather | 更适合全行/列的选择,而 gather 是逐元素选择 |
torch.index_select()
是构建数据选择、索引变换等操作的基本工具,特别适用于 NLP、图神经网络中嵌入提取、子图构造等任务。