高级操作---where 和 gather

本文详细介绍了PyTorch中的两个高级函数:where和gather。where函数可以根据条件选择不同的tensor元素,而gather函数则用于从指定维度上收集tensor的元素。文章通过具体的代码示例展示了这两个函数的应用。
部署运行你感兴趣的模型镜像

一、where函数

1、简介

  1. return a tensor of elements selected from either x or y, depending on condition
  2. 返回一个tensor,根据生成规则选择是从x取出还是y取出。
  3. 虽然where函数操作可以用for循环代替,但是where是通过GPU加速的,在深度学习中使用where速度会更快。

2.函数

  1. torch.where(cond,x,y)
  2. cond:生成规则自定义
  3. x,y:两个tensor

3.代码实例

cond = torch.rand(2,2) # 倾向于a的概率
print(cond)
# tensor([[0.7894, 0.3648],
#        [0.4843, 0.3184]])

a = torch.full([2,2],0)

b = torch.full([2,2],1)
# 规则是如果cond>0.5就选择a,否则选择b
ans = torch.where(cond>0.5,a,b)
print(ans)
# tensor([[0., 1.],
#        [1., 1.]])

二、gather函数

1、简介

  1. 官方文档:
    1. torch.gather(input,dim,index,out=None) -> Tensor
    2. Gather values along an axis specified by dim
    3. For a 3D tensor the output is specified by:
      1. out[i][j][k] = input[ index[i][j][k] ][j][k], dim = 0
      2. out[i][j][k] = input[i][ index[i][j][k] ][k], dim = 1
      3. out[i][j][k] = input[i][j][ index[i][j][k] ] , dim = 2
  2. 从原tensor中获取指定dim和指定index的数据

2、函数

  1. torch.gather(input,dim,index,out=None)
  2. input:保持与index.shape一致,从input中获取数值
  3. dim:从input中的哪一维度获取
  4. index:保持与input.shape一致,可以理解为tensor下标,根据index生成最终的tensor

3、代码实例

prob = torch.randn(4,10) # 随机生成shape为[4,10]的tensor
idx = prob.topk(dim=1,k=3) # 生成前3大的tensor
idx = idx[1]
print(idx)
# tensor([[8, 7, 9],
#        [1, 3, 2],
#        [3, 1, 6],
#        [7, 8, 9]])

label = torch.arange(10) + 100 # 偏移
print(label)
# tensor([100, 101, 102, 103, 104, 105, 106, 107, 108, 109])

ans = torch.gather(label.expand(4,10),dim=1,index=idx.long()) # 对列索引

print(ans)
# tensor([[108, 107, 109],
#         [101, 103, 102],
#        [103, 101, 106],
#        [107, 108, 109]])

您可能感兴趣的与本文相关的镜像

PyTorch 2.5

PyTorch 2.5

PyTorch
Cuda

PyTorch 是一个开源的 Python 机器学习库,基于 Torch 库,底层由 C++ 实现,应用于人工智能领域,如计算机视觉和自然语言处理

<think>我们首先需要理解用户的问题:用户希望了解PyTorch中创建张量的方法(特别是`scores.new_tensor(0)`),以及`torch.where``gather`函数的使用方法。 根据引用内容,我们可以参考: 引用[2]中提到了`torch.where``torch.masked_fill`等操作。 引用[1]引用[4]中提到了`gather``scatter`函数。 下面我们分三部分来回答: 1. `scores.new_tensor(0)` 创建张量 2. `torch.where` 函数的使用方法 3. `gather` 函数的使用方法 注意:用户还提到了`scatter`,但问题中只要求了`gather`,不过由于`gather``scatter`是互逆操作,我们可以简单提及,但重点在`gather`。 第一部分:`scores.new_tensor(0)` 在PyTorch中,`new_tensor`是一个方法,它根据已有的张量(这里是`scores`)创建一个新的张量,新张量的数据类型设备(CPU/GPU)与原有张量相同。 例如,`scores.new_tensor(0)`会创建一个标量张量,其值为0,并且这个新张量与`scores`在同一个设备上,具有相同的数据类型。 第二部分:`torch.where` 引用[2]中提到了`torch.where`,它的作用是根据条件选择两个张量中的元素。函数签名如下: torch.where(condition, x, y) 其中condition是一个布尔张量,当condition的某个位置为True时,输出张量在该位置取x的对应位置的值,否则取y的对应位置的值。 第三部分:`gather` 引用[1]引用[4]中提到了`gather`,它的作用是根据索引从输入张量中提取指定位置的元素。函数签名如下: torch.gather(input, dim, index, out=None, sparse_grad=False) 它沿着指定的维度dim,根据索引index从输入张量中提取元素。索引张量index的形状必须与输出张量相同,且每个位置的值表示在dim维度上的索引。 我们将通过代码示例详细说明。 注意:在回答中,我们按照要求使用LaTeX格式的数学表达式(如张量形状等),并正确引用。</think>### PyTorch张量创建与操作详解 #### 1. `scores.new_tensor(0)` 创建张量 `new_tensor()`方法基于现有张量创建新张量,保留原始张量的设备类型数据类型属性: ```python import torch # 原始张量(假设在GPU上) scores = torch.tensor([1.0, 2.0], device='cuda', dtype=torch.float32) # 创建新张量(继承设备类型数据类型) new_tensor = scores.new_tensor(0) # 标量0 print(new_tensor) # tensor(0., device='cuda:0') print(new_tensor.device) # cuda:0 print(new_tensor.dtype) # torch.float32 ``` **关键特性**: - 继承原始张量的设备(CPU/GPU)数据类型 - 输入可以是标量、列表、NumPy数组等 - 等价操作:`torch.tensor(0, device=scores.device, dtype=scores.dtype)` #### 2. `torch.where()` 条件选择 `torch.where()`根据条件选择两个张量的元素,语法为: ```python torch.where(condition, x, y) # 当condition为True时选x,否则选y ``` **示例:阈值处理**(引用[2]): ```python t = torch.tensor([[3, 7], [2, 5], [9, 1]]) # 大于5的元素保留,否则置0 result = torch.where(t > 5, t, torch.zeros_like(t)) print(result) # tensor([[0, 7], # [0, 0], # [9, 0]]) ``` **高级应用**: ```python # 创建掩码 mask = torch.tensor([[True, False], [False, True]]) a = torch.tensor([[1, 2], [3, 4]]) b = torch.tensor([[10, 20], [30, 40]]) # 根据掩码选择元素 combined = torch.where(mask, a, b) print(combined) # tensor([[ 1, 20], [30, 4]]) ``` #### 3. `torch.gather()` 索引收集 `gather()`根据索引在指定维度收集元素,语法为: ```python output = torch.gather(input, dim, index) ``` **基本用法**(引用[1][4]): ```python data = torch.tensor([[1, 2], [3, 4], [5, 6]]) # 3x2 index = torch.tensor([[0, 0], [1, 0], [0, 1]]) # 3x2 # 沿dim=1收集 result = torch.gather(data, 1, index) print(result) # tensor([[1, 1], # 取第0行: [1,2]中的[0,0]位置 # [4, 3], # 取第1行: [3,4]中的[1,0]位置 # [5, 6]]) # 取第2行: [5,6]中的[0,1]位置 ``` **实际应用:one-hot编码**(引用[4]): ```python labels = torch.tensor([2, 0, 1]) # 3个样本的类别标签 one_hot = torch.zeros(3, 3) # 3x3零矩阵 # 使用gather生成one-hot one_hot.scatter_(1, labels.unsqueeze(1), 1) print(one_hot) # tensor([[0., 0., 1.], # [1., 0., 0.], # [0., 1., 0.]]) ``` #### 4. 相关函数对比 | 函数 | 用途 | 特点 | |------|------|------| | `new_tensor()` | 创建继承属性的新张量 | 保留设备/数据类型[^1] | | `torch.where()` | 条件选择元素 | 类似三元运算符[^2] | | `torch.gather()` | 按索引收集数据 | 需指定维度[^1][^4] | | `torch.scatter()` | 按索引写入数据 | `gather`的逆操作[^1][^4] | > 引用说明:`new_tensor`继承设备类型数据类型属性[^1];`where`实现条件选择[^2];`gather``scatter`是索引操作的互补函数[^1][^4]。 ---
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值