【PyTorch】torch.index_select() 函数:某一维度上按索引选取元素

torch.index_select() 函数详解

torch.index_select() 是 PyTorch 中用于在 某一维度上按索引选取元素 的函数,适用于从张量中抽取特定的行、列或更高维度切片,是比切片更灵活的选择操作


1. 函数原型

torch.index_select(input, dim, index) → Tensor
参数说明
input输入张量
dim要在哪一维度上选取索引(例如:0 表示行,1 表示列)
index一个 1D 的 LongTensor,表示要选取的索引位置
返回值新张量,包含从 input 中选取的元素

2. 基本用法示例

2.1 按行选取(dim=0

import torch

x = torch.tensor([[10, 11],
                  [12, 13],
                  [14, 15]])

# 选第 2 行、第 0 行
idx = torch.tensor([2, 0])
result = torch.index_select(x, dim=0, index=idx)
print(result)

输出:

tensor([[14, 15],
        [10, 11]])

2.2 按列选取(dim=1

x = torch.tensor([[10, 11, 12],
                  [13, 14, 15]])

# 选第 2 列和第 0 列
idx = torch.tensor([2, 0])
result = torch.index_select(x, dim=1, index=idx)
print(result)

输出:

tensor([[12, 10],
        [15, 13]])

3. 应用场景

3.1 选择特定样本或特征

data = torch.randn(100, 64)  # 100 个样本,每个 64 维
selected = torch.index_select(data, dim=0, index=torch.tensor([0, 5, 10]))

3.2 从字典或嵌入中选择词向量

embedding = torch.randn(1000, 300)  # 假设词表大小为 1000,每个词 300 维
word_ids = torch.tensor([3, 7, 9])
word_vectors = torch.index_select(embedding, 0, word_ids)

4. 与其他函数的对比

函数功能特点
index_select按维度选择给定索引的切片输入索引必须是 1D
gather按每个位置指定索引选元素索引 shape 与输出 shape 相同
select选定某一维的一个具体位置只能选一个 index
slice连续索引(如 0:3)不支持不连续

5. 注意事项

  • index 必须是 1D 的 LongTensor
  • 返回的是一个 新的张量,原始张量不会改变。
  • dim 的值不能超过 input.ndim - 1

6. 总结

特性说明
用途在指定维度上按索引选取元素(如选特定行/列)
索引类型1D LongTensor
常用场景采样、提取子集、词向量查表、数据过滤
区别于 gather更适合全行/列的选择,而 gather 是逐元素选择

torch.index_select() 是构建数据选择、索引变换等操作的基本工具,特别适用于 NLP、图神经网络中嵌入提取、子图构造等任务。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

彬彬侠

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值