高级操作---where 和 gather

本文详细介绍了PyTorch中的两个高级函数:where和gather。where函数可以根据条件选择不同的tensor元素,而gather函数则用于从指定维度上收集tensor的元素。文章通过具体的代码示例展示了这两个函数的应用。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

一、where函数

1、简介

  1. return a tensor of elements selected from either x or y, depending on condition
  2. 返回一个tensor,根据生成规则选择是从x取出还是y取出。
  3. 虽然where函数操作可以用for循环代替,但是where是通过GPU加速的,在深度学习中使用where速度会更快。

2.函数

  1. torch.where(cond,x,y)
  2. cond:生成规则自定义
  3. x,y:两个tensor

3.代码实例

cond = torch.rand(2,2) # 倾向于a的概率
print(cond)
# tensor([[0.7894, 0.3648],
#        [0.4843, 0.3184]])

a = torch.full([2,2],0)

b = torch.full([2,2],1)
# 规则是如果cond>0.5就选择a,否则选择b
ans = torch.where(cond>0.5,a,b)
print(ans)
# tensor([[0., 1.],
#        [1., 1.]])

二、gather函数

1、简介

  1. 官方文档:
    1. torch.gather(input,dim,index,out=None) -> Tensor
    2. Gather values along an axis specified by dim
    3. For a 3D tensor the output is specified by:
      1. out[i][j][k] = input[ index[i][j][k] ][j][k], dim = 0
      2. out[i][j][k] = input[i][ index[i][j][k] ][k], dim = 1
      3. out[i][j][k] = input[i][j][ index[i][j][k] ] , dim = 2
  2. 从原tensor中获取指定dim和指定index的数据

2、函数

  1. torch.gather(input,dim,index,out=None)
  2. input:保持与index.shape一致,从input中获取数值
  3. dim:从input中的哪一维度获取
  4. index:保持与input.shape一致,可以理解为tensor下标,根据index生成最终的tensor

3、代码实例

prob = torch.randn(4,10) # 随机生成shape为[4,10]的tensor
idx = prob.topk(dim=1,k=3) # 生成前3大的tensor
idx = idx[1]
print(idx)
# tensor([[8, 7, 9],
#        [1, 3, 2],
#        [3, 1, 6],
#        [7, 8, 9]])

label = torch.arange(10) + 100 # 偏移
print(label)
# tensor([100, 101, 102, 103, 104, 105, 106, 107, 108, 109])

ans = torch.gather(label.expand(4,10),dim=1,index=idx.long()) # 对列索引

print(ans)
# tensor([[108, 107, 109],
#         [101, 103, 102],
#        [103, 101, 106],
#        [107, 108, 109]])
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值