torch.gather()
函数详解
torch.gather()
是 PyTorch 中一个功能非常强大的函数,用于在指定维度上 根据索引张量从输入张量中收集元素。它在深度学习中常用于:
- 分类模型中选取目标类别的预测分数
- 掩码或注意力机制中抽取特定位置
- 多类标签、序列任务中对齐标签和预测
1. 函数原型
torch.gather(input, dim, index, *, sparse_grad=False, out=None) → Tensor
参数 | 说明 |
---|---|
input | 输入张量 |
dim | 沿着哪个维度进行索引 |
index | 索引张量,表示要从 input 中选取哪些元素(其 shape 必须与输出 shape 一致) |
sparse_grad | 是否返回稀疏梯度(通常设为 False) |
out | 可选输出张量 |
2. 功能说明
gather
会根据 index
中的每个元素,沿着 dim
维度去 input 中选择对应位置的值。
换句话说,对于某个位置 i,j,...
,返回:
[
\text{output}[i,j,…] = \text{input}[i,j,…, \text{index}[i,j,…]]
]
3. 示例讲解
3.1 在第 1 维(列)索引:dim=1
import torch
input = torch.tensor([[10, 20, 30],
[40, 50, 60]])
index = torch.tensor([[2, 1, 0],
[0, 2, 1]])
output = torch.gather(input, dim=1, index=index)
print(output)
输出:
tensor([[30, 20, 10],
[40, 60, 50]])
解释:
- 第 1 行取:第 2列(30)、第 1列(20)、第 0列(10)
- 第 2 行取:第 0列(40)、第 2列(60)、第 1列(50)
3.2 在第 0 维(行)索引:dim=0
input = torch.tensor([[10, 20, 30],
[40, 50, 60]])
index = torch.tensor([[1, 0, 0],
[0, 1, 1]])
output = torch.gather(input, dim=0, index=index)
print(output)
输出:
tensor([[40, 20, 30],
[10, 50, 60]])
4. 实际应用场景示例
4.1 分类模型中提取目标类别预测概率
# 假设模型输出 logits
logits = torch.tensor([[1.2, 0.5, 2.1],
[0.1, 3.2, 0.3]]) # shape: (batch, num_classes)
# 真实标签
labels = torch.tensor([[2], [1]]) # shape: (batch, 1)
# 提取每个样本对应标签的分数
selected_logits = torch.gather(logits, dim=1, index=labels)
print(selected_logits)
输出:
tensor([[2.1000],
[3.2000]])
5. 注意事项
index
的形状必须和输出形状一致。index
中的值必须是input
中对应维度上的合法索引(即不能越界)。dim
决定从哪一维度上选择数据。
6. 与 index_select()
的区别
特性 | gather | index_select |
---|---|---|
维度 | 可选择任意维度 | 固定选一维 |
索引形状 | 与输出形状一致 | 是一维 |
用途 | 高级索引、匹配标签等 | 简单抽取某些行/列 |
7. 总结
功能 | 说明 |
---|---|
动态选择元素 | 根据索引张量在某维度选择数据 |
支持广播 | 否,index 必须与输出形状一致 |
常用于 | 分类提取、注意力机制、Transformer、序列任务对齐等 |
torch.gather()
是构建深度学习模型中 动态索引操作的核心工具,特别是在分类、序列、注意力等任务中非常常见。