PyTorch张量检索张量,返回索引值

这篇博客介绍了如何在Python中处理张量,特别是通过遍历的方法寻找一个张量在另一个张量中的索引。提供的代码实现了一个函数`index_tensor_by_tensor`,用于在张量a中查找张量c的索引,如果找到则返回对应的索引值,否则返回-1。此外,还展示了如何应用这个函数到实际例子中,分别处理了存在匹配和不存在匹配的情况。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

一、问题描述

        假设有三个张量a、b和c,其中张量a的shape为(8, 2),张量b的shape为(8, 5),张量c的shape为(1, 2)。张量a和张量b是对应的关系,并且张量a和张量b每行元素都唯一、不会重复。而我需要通过检索的方式来获得张量c在张量a中的索引(索引值大于等于0),如果没有那么就返回一个异常值(-1)。例如

a = [ [1,2],

        [2,3],

        [3,4],

        [3,8],

        [4,5],

        [4,1],

        [5,6],

        [5,5] ]

b = [ [0, 0, 0, 0, 0],

        [0, 0, 0, 0, 1],

        [0, 0, 0, 1, 0],

        [0, 0, 0, 1, 1],

        [0, 0, 1, 0, 0],

        [0, 0, 1, 0, 1],

        [0, 0, 1, 1, 0],

        [0, 0, 1, 1, 1] ]

c = [ [4,1] ]

        我想要得到张量c在张量a中的索引值5,进而得到张量b中的[0, 0, 1, 0, 1]

二、解决方法

        代码(主要是采取遍历的思想,因为实在是找不到太好的api)

import torch

# input_: a 2d tensor
# query_: a 2d tensor
# function: get the index of query_ in input_
def index_tensor_by_tensor(input_, query_):
    # default index of result
    idx = -1
    # index range of tensor2d_a
    n_a = input_.shape[0]
    # traverse tensor2d_a
    for i in range(n_a):
        if input_[i][0] == query_[0][0] and input_[i][1] == query_[0][1]:
            # find the query tensor
            idx = i
            break
    return idx

# main function
if __name__ == '__main__':
    # input
    a = torch.tensor([[1, 2],
                      [2, 3],
                      [3, 4],
                      [3, 8],
                      [4, 5],
                      [4, 1],
                      [5, 6],
                      [5, 5]])
    b = torch.tensor([[0, 0, 0, 0, 0],
                      [0, 0, 0, 0, 1],
                      [0, 0, 0, 1, 0],
                      [0, 0, 0, 1, 1],
                      [0, 0, 1, 0, 0],
                      [0, 0, 1, 0, 1],
                      [0, 0, 1, 1, 0],
                      [0, 0, 1, 1, 1]])
    c_1 = torch.tensor([[4, 1]])    # it is exist in a
    c_2 = torch.tensor([[7, 1]])    # it is not exist in a

    # query
    res_i_1 = index_tensor_by_tensor(a, c_1)
    res_i_2 = index_tensor_by_tensor(a, c_2)

    # output
    if res_i_1 != -1:
        print(res_i_1, b.select(0, res_i_1))
    else:
        print(res_i_1, None)
    if res_i_2 != -1:
        print(res_i_2, b.select(0, res_i_2))
    else:
        print(res_i_2, None)

        效果

5 tensor([0, 0, 1, 0, 1])
-1 None

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

飞机火车巴雷特

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值