【PyTorch】torch.select() 函数:某一维度上 选择一个特定索引位置的切片

torch.select() 函数详解

torch.select() 是 PyTorch 中一个较少直接使用但功能明确的函数,用于在某一维度上 选择特定索引位置的切片只能选择一个 index,不支持多个或批量选择


1. 函数原型

torch.select(input, dim, index) → Tensor
参数说明
input输入张量
dim指定在哪个维度上选取
index在该维度上的索引位置(只能是一个整数)
返回值一个新的张量,维度比 input 少一维(去除了 dim

2. 功能说明

  • dim 指定的维度上,选取第 index 个元素
  • 与切片类似于 x[dim=index],但更通用、可编程;
  • 返回的张量将会 去除该维度(即降维)

3. 示例讲解

3.1 示例:按行选择(dim=0

import torch

x = torch.tensor([[10, 11, 12],
                  [13, 14, 15],
                  [16, 17, 18]])

row = torch.select(x, dim=0, index=1)
print(row)  # tensor([13, 14, 15])

等价于:

x[1]

3.2 示例:按列选择(dim=1

column = torch.select(x, dim=1, index=2)
print(column)
# 输出: tensor([12, 15, 18])

等价于:

x[:, 2]

4. 与其他索引函数的对比

函数功能特点
select选某个维度的一个 index只能选一个索引,返回张量降维
index_select在某维选多个 index(1D LongTensor)支持批量选择,返回保持维度
gather在某维按 index 张量选择每个位置的值支持高级索引和批量抽取

5. 使用场景

  • 对于只想选取某一维的一个特定切片时,使用 select() 更清晰:
    • 比如取第 i 张图片的第 j 个通道;
    • 从序列中取第 t 个时间步的隐藏状态;
  • 在编程接口中比 x[i]x[:, i] 更通用。

6. 注意事项

  • index 必须是 单个整数,不是张量。
  • 返回张量会 降维(该维度被移除)
  • 如果要保留维度,可使用 x.index_select(dim, torch.tensor([index]))

7. 总结

特性说明
功能沿指定维度选取一个位置的切片
返回张量(比原张量少一维)
优点简洁、高效、可编程
适合场景选某一行、某一列、某个时间步等单点切片场景

虽然 torch.select() 功能简单,但在构建动态维度操作或作为底层组件时很有用。它在自定义模型结构、动态时间步处理等任务中偶尔会派上用场。如需更复杂的选择操作,推荐使用 index_select()gather()

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

彬彬侠

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值