torch.gather的使用及理解

结论:使用方法

# gather,沿dim指定的轴收集值。
y_hat.gather(1, y.view(-1, 1))# y.view(-1, 1)会变成一列,y_hat的取y作为的索引的值

分步理解:先创建一个2*3的tensor

>>y_hat = torch.tensor([[0.1, 0.3, 0.6], [0.3, 0.2, 0.5]])

tensor([[0.1000, 0.3000, 0.6000],
        [0.3000, 0.2000, 0.5000]])

为了使用gather函数,我们得创建一个tensor作为gather得参数

>>y = torch.LongTensor([0, 2])

tensor([0, 2])

我们需要把y变个形状

>>y.view(-1, 1)

tensor([[0],
        [2]])

先来看看使用得结果

>>y_hat.gather(1, y.view(-1, 1))
tensor([[0.1000],
        [0.5000]])

图解:

 

 

 

 
### PyTorch 中 `torch.gather` 和 `repeat` 的搭配用法 #### 什么是 `torch.gather`? `torch.gather` 是 PyTorch 中的一个操作,用于从输入张量的不同维度收集指定索引位置上的元素。它允许按照给定的索引来提取特定的数据片段。 其基本语法为: ```python torch.gather(input, dim, index, out=None, sparse_grad=False) ``` 其中: - `input`: 输入张量。 - `dim`: 要沿着哪个维度进行聚集操作。 - `index`: 指定要采集的索引值,形状需与目标一致[^1]。 #### 什么是 `repeat` 或 `.expand()` 方法? `.repeat()` 是一种扩展张量的方法,可以重复某个张量的内容以匹配新的形状。它的作用类似于 NumPy 的 `tile` 函数。对于简单的广播机制不足的情况,可以通过 `.repeat()` 来实现更复杂的形状调整。 其基本语法为: ```python tensor.repeat(*sizes) ``` 参数 `*sizes` 表示各个维度上需要重复的次数[^2]。 --- #### 组合使用场景分析 当我们将 `torch.gather` 和 `repeat` 结合起来时,通常是为了处理以下情况之一: 1. **动态索引选择并复制**:先通过 `gather` 收集某些特定数据,再利用 `repeat` 将这些数据沿某一维度扩展到更大的规模。 2. **批量操作中的灵活映射**:在批处理模式下,针对不同样本执行个性化的特征抽取或变换。 下面是一个具体的例子展示如何组合这两个功能。 --- #### 使用案例 假设有一个二维矩阵 A,以及一组对应每一行的最大值索引 B。我们需要根据索引 B 提取每行最大值,并将其扩展成一个新的三维张量 C。 ##### 示例代码 ```python import torch # 创建一个随机二维张量 A (batch_size=3, feature_dim=4) A = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]], dtype=torch.float) # 假设我们知道每行的最大值对应的列索引 B = torch.tensor([3, 2, 1]) # shape: (batch_size,) -> [3, 2, 1] # Step 1: 使用 gather 获取每行的最大值 max_values_per_row = torch.gather(A, 1, B.unsqueeze(1)) # unsqueeze 使 B 变为 (batch_size, 1) print("Max values per row:", max_values_per_row.squeeze()) # Output: tensor([4., 7., 10.]) # Step 2: 扩展结果至更高维空间 C = max_values_per_row.expand(-1, 4).unsqueeze(2) # expand 到 (batch_size, 4), 并增加一维变为 (batch_size, 4, 1) D = C.repeat(1, 1, 3) # 进一步 repeat 成 (batch_size, 4, 3) print(D.shape) # 输出应为 torch.Size([3, 4, 3]) ``` 上述代码展示了如何结合 `torch.gather` 和 `repeat` 完成复杂的数据转换任务。这里的关键在于理解 `gather` 如何定位所需数据,以及 `repeat` 怎样改变张量结构[^3]。 --- #### 注意事项 1. 当调用 `gather` 时,确保索引张量 (`index`) 的尺寸与目标维度相兼容。 2. 对于高阶张量的操作,务必注意各维度的意义及其顺序关系。 3. 如果涉及 GPU 计算,请确认所有参与运算的对象均位于同一设备之上[^4]。 ---
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

dlage

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值