【PyTorch】torch.gather() 函数:根据索引张量从输入张量中收集元素

torch.gather() 函数详解

torch.gather() 是 PyTorch 中一个功能非常强大的函数,用于在指定维度上 根据索引张量从输入张量中收集元素。它在深度学习中常用于:

  • 分类模型中选取目标类别的预测分数
  • 掩码或注意力机制中抽取特定位置
  • 多类标签、序列任务中对齐标签和预测

1. 函数原型

torch.gather(input, dim, index, *, sparse_grad=False, out=None) → Tensor
参数说明
input输入张量
dim沿着哪个维度进行索引
index索引张量,表示要从 input 中选取哪些元素(其 shape 必须与输出 shape 一致)
sparse_grad是否返回稀疏梯度(通常设为 False)
out可选输出张量

2. 功能说明

gather 会根据 index 中的每个元素,沿着 dim 维度去 input 中选择对应位置的值

换句话说,对于某个位置 i,j,...,返回:
[
\text{output}[i,j,…] = \text{input}[i,j,…, \text{index}[i,j,…]]
]


3. 示例讲解

3.1 在第 1 维(列)索引:dim=1

import torch

input = torch.tensor([[10, 20, 30],
                      [40, 50, 60]])

index = torch.tensor([[2, 1, 0],
                      [0, 2, 1]])

output = torch.gather(input, dim=1, index=index)
print(output)

输出:

tensor([[30, 20, 10],
        [40, 60, 50]])

解释:

  • 第 1 行取:第 2列(30)、第 1列(20)、第 0列(10)
  • 第 2 行取:第 0列(40)、第 2列(60)、第 1列(50)

3.2 在第 0 维(行)索引:dim=0

input = torch.tensor([[10, 20, 30],
                      [40, 50, 60]])

index = torch.tensor([[1, 0, 0],
                      [0, 1, 1]])

output = torch.gather(input, dim=0, index=index)
print(output)

输出:

tensor([[40, 20, 30],
        [10, 50, 60]])

4. 实际应用场景示例

4.1 分类模型中提取目标类别预测概率

# 假设模型输出 logits
logits = torch.tensor([[1.2, 0.5, 2.1],
                       [0.1, 3.2, 0.3]])  # shape: (batch, num_classes)

# 真实标签
labels = torch.tensor([[2], [1]])  # shape: (batch, 1)

# 提取每个样本对应标签的分数
selected_logits = torch.gather(logits, dim=1, index=labels)
print(selected_logits)

输出:

tensor([[2.1000],
        [3.2000]])

5. 注意事项

  • index 的形状必须和输出形状一致。
  • index 中的值必须是 input 中对应维度上的合法索引(即不能越界)。
  • dim 决定从哪一维度上选择数据。

6. 与 index_select() 的区别

特性gatherindex_select
维度可选择任意维度固定选一维
索引形状与输出形状一致是一维
用途高级索引、匹配标签等简单抽取某些行/列

7. 总结

功能说明
动态选择元素根据索引张量在某维度选择数据
支持广播否,index 必须与输出形状一致
常用于分类提取、注意力机制、Transformer、序列任务对齐等

torch.gather() 是构建深度学习模型中 动态索引操作的核心工具,特别是在分类、序列、注意力等任务中非常常见。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

彬彬侠

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值