**
torch.gather使用
**
import torch
tensor_0 = torch.arange(3, 12).view(3, 3)
print(tensor_0)
#tensor([[ 3, 4, 5],
# [ 6, 7, 8],
# [ 9, 10, 11]])
index = torch.tensor([[2, 1, 0]]).t()
tensor_1 = tensor_0.gather(1, index)
print(tensor_1)
**
torch.searchsorted使用
**
import torch.nn as nn
import torch
def main():
x = torch.linspace(1, 11, 10)
y = torch.linspace(2, 12, 10)
x=x.unsqueeze(-1).expand(x.size(0),10)
y=y.unsqueeze(-1).expand(y.size(0),100)
print(x.shape,y.shape)
res1 = torch.searchsorted(x, y)
print(res1)
main()
# torch.Size([10, 10]) torch.Size([10, 100])
# tensor([[10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10,
# 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10,
# 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10,
# 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10,
# 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10,
# 10, 10, 10, 10, 10, 10, 10, 10, 10, 10],
# ...
# [[10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10,
# 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10,
# 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10,
# 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10,
# 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10,
# 10, 10, 10, 10, 10, 10, 10, 10, 10, 10]])
torch.searchsorted()函数的作用是x经过排序过后,y在x中寻找自己的每个元素的位置,输出的维度与y保持一致。