一、where函数
1、简介
- return a tensor of elements selected from either x or y, depending on condition
- 返回一个tensor,根据生成规则选择是从x取出还是y取出。
- 虽然where函数操作可以用for循环代替,但是where是通过GPU加速的,在深度学习中使用where速度会更快。
2.函数
torch.where(cond,x,y)
cond
:生成规则自定义x,y
:两个tensor
3.代码实例
cond = torch.rand(2,2)
print(cond)
a = torch.full([2,2],0)
b = torch.full([2,2],1)
ans = torch.where(cond>0.5,a,b)
print(ans)
二、gather函数
1、简介
- 官方文档:
torch.gather(input,dim,index,out=None) -> Tensor
- Gather values along an axis specified by dim
- For a 3D tensor the output is specified by:
- out[i][j][k] = input[ index[i][j][k] ][j][k], dim = 0
- out[i][j][k] = input[i][ index[i][j][k] ][k], dim = 1
- out[i][j][k] = input[i][j][ index[i][j][k] ] , dim = 2
- 从原tensor中获取指定dim和指定index的数据
2、函数
torch.gather(input,dim,index,out=None)
input
:保持与index.shape
一致,从input中获取数值dim
:从input
中的哪一维度获取index
:保持与input.shape
一致,可以理解为tensor下标,根据index生成最终的tensor
3、代码实例
prob = torch.randn(4,10)
idx = prob.topk(dim=1,k=3)
idx = idx[1]
print(idx)
label = torch.arange(10) + 100
print(label)
ans = torch.gather(label.expand(4,10),dim=1,index=idx.long())
print(ans)