【pytorch】torch.cdist使用说明

torch.cdist函数用于批量计算两个向量集合之间的欧几里德距离,等同于scipy.spatial.distance.cdist。它接受输入的两个向量集合x1和x2,计算每个元素对之间的距离。当p=2时,是欧几里德距离;若p=1,则为L1范式。文中通过示例展示了如何使用该函数以及二维情况下的应用。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

使用说明

torch.cdist的使用介绍如官网所示,

在这里插入图片描述

它是批量计算两个向量集合的距离。

其中, x1和x2是输入的两个向量集合。

p 默认为2,为欧几里德距离。

它的功能上等同于 scipy.spatial.distance.cdist(input,’minkowski’, p=p)

如果x1的shape是 [B,P,M], x2的shape是[B,R,M],则cdist的结果shape是 [B,P,R]

进一步的解释

x1一般是输入矢量,而x2一般是码本。

x2中所有的元素分别与x1中的每一个元素求欧几里德距离(当p默认为2时)

如下面示例

import torch

x1 = torch.FloatTensor([0.1, 0.2, 0, 0.5]).view(4, 1)

x2 = torch.FloatTensor([0.2, 0.3]).view(2, 1)

print(torch.cdist(x1,x2))

x2中的所有元素分别与x1中的每一个元素求欧几里德距离,即有如下步骤

x 11 = ( 0.1 − 0.2 ) 2 = 0.1 x 12 = ( 0.1 − 0.3 ) 2 = 0.2 x 21 = ( 0.2 − 0.2 ) 2 = 0 x 22 = ( 0.2 − 0.3 ) 2 = 0.1 x 31 = ( 0 − 0.2 ) 2 = 0.2 x 32 = ( 0 − 0.3 ) 2 = 0.3 x 41 = ( 0.5 − 0.2 ) 2 = 0.3 x 42 = ( 0.5 − 0.3 ) 2 = 0.2 x_{11} = \sqrt{ (0.1-0.2)^2} = 0.1 \newline x_{12} = \sqrt { (0.1-0.3)^2} = 0.2 \newline x_{21} = \sqrt { (0.2-0.2)^2} = 0 \newline x_{22} = \sqrt { (0.2-0.3)^2} = 0.1 \newline x_{31} = \sqrt { (0-0.2)^2} = 0.2 \newline x_{32} = \sqrt { (0-0.3)^2} = 0.3 \newline x_{41} = \sqrt { (0.5-0.2)^2 } =0.3\newline x_{42} = \sqrt { (0.5-0.3)^2 } = 0.2\newline x11=(0.10.2)2 =0.1x12=(0.10.3)2 =0.2x21=(0.20.2)2 =0x22=(0.20.3)2 =0.1x31=(00.2)2 =0.2x32=(00.3)2 =0.3x41=(0.50.2)2 =0.3x42=(0.50.3)2 =0.2

所以运行结果为
在这里插入图片描述

扩张到2维的情况

如下面示例

import torch

x1 = torch.FloatTensor([0.1, 0.2, 0.1, 0.5, 0.2, -0.9, 0.8, 0.4]).view(4, 2)

x2 = torch.FloatTensor([0.2, 0.3, 0, 0.1]).view(2, 2)

print(torch.cdist(x1,x2))

x1和x2数据是二维的,
在这里插入图片描述

x2中的所有元素分别与x1中的每一个元素求欧几里德距离,即有如下步骤

x 11 = ( 0.1 − 0.2 ) 2 + ( 0.2 − 0.3 ) 2 = 0.02 = 0.1414 x 12 = ( 0.1 − 0.0 ) 2 + ( 0.2 − 0.1 ) 2 = 0.02 = 0.1414 x 21 = ( 0.1 − 0.2 ) 2 + ( 0.5 − 0.3 ) 2 = 0.05 = 0.2236 x 22 = ( 0.1 − 0.0 ) 2 + ( 0.5 − 0.1 ) 2 = 0.17 = 0.4123 x 31 = ( 0.2 − 0.2 ) 2 + ( − 0.9 − 0.3 ) 2 = 1.2 x 32 = ( 0.2 − 0.0 ) 2 + ( − 0.9 − 0.1 ) 2 = ( 1.04 ) = 1.0198 x 41 = ( 0.8 − 0.2 ) 2 + ( 0.4 − 0.3 ) 2 = ( 0.37 ) = 0.6083 x 42 = ( 0.8 − 0.0 ) 2 + ( 0.4 − 0.1 ) 2 = ( 0.73 ) = 0.8544 x_{11} = \sqrt{ (0.1-0.2)^2 + (0.2-0.3)^2 } = \sqrt{0.02} = 0.1414 \newline x_{12} = \sqrt { (0.1-0.0)^2 + (0.2-0.1)^2 } = \sqrt{0.02} = 0.1414 \newline x_{21} = \sqrt { (0.1-0.2)^2 + (0.5-0.3)^2 } = \sqrt{0.05} = 0.2236 \newline x_{22} = \sqrt { (0.1-0.0)^2 + (0.5-0.1)^2 } = \sqrt{0.17} = 0.4123 \newline x_{31} = \sqrt { (0.2-0.2)^2 + (-0.9-0.3)^2} = 1.2 \newline x_{32} = \sqrt { (0.2-0.0)^2 + (-0.9-0.1)^2} = \sqrt(1.04) = 1.0198 \newline x_{41} = \sqrt { (0.8-0.2)^2 + (0.4-0.3)^2 } = \sqrt(0.37) = 0.6083 \newline x_{42} = \sqrt { (0.8-0.0)^2 + (0.4-0.1)^2 } = \sqrt(0.73) = 0.8544 \newline x11=(0.10.2)2+(0.20.3)2 =0.02 =0.1414x12=(0.10.0)2+(0.20.1)2 =0.02 =0.1414x21=(0.10.2)2+(0.50.3)2 =0.05 =0.2236x22=(0.10.0)2+(0.50.1)2 =0.17 =0.4123x31=(0.20.2)2+(0.90.3)2 =1.2x32=(0.20.0)2+(0.90.1)2 =( 1.04)=1.0198x41=(0.80.2)2+(0.40.3)2 =( 0.37)=0.6083x42=(0.80.0)2+(0.40.1)2 =( 0.73)=0.8544

所以结果如下

在这里插入图片描述

p=2的欧几里德距离也是L2范式,如果p=1即是L1范式
上面的例子修改一下p参数

import torch

x1 = torch.FloatTensor([0.1, 0.2, 0.1, 0.5, 0.2, -0.9, 0.8, 0.4]).view(4, 2)

x2 = torch.FloatTensor([0.2, 0.3, 0, 0.1]).view(2, 2)

print(torch.cdist(x1,x2,p=1))

结果如下,这里就不一个一个运算了。
在这里插入图片描述

详细解释一下这段代码,每一句都要进行注解:def get_image_pairs_shortlist(fnames, sim_th = 0.6, # should be strict min_pairs = 20, exhaustive_if_less = 20, device=torch.device('cpu')): num_imgs = len(fnames) if num_imgs <= exhaustive_if_less: return get_img_pairs_exhaustive(fnames) model = timm.create_model('tf_efficientnet_b7', checkpoint_path='/kaggle/input/tf-efficientnet/pytorch/tf-efficientnet-b7/1/tf_efficientnet_b7_ra-6c08e654.pth') model.eval() descs = get_global_desc(fnames, model, device=device) #这段代码使用 PyTorch 中的 torch.cdist 函数计算两个矩阵之间的距离,其中参数 descs 是一个矩阵,表示一个数据集中的所有样本的特征向量。函数将计算两个矩阵的 p 范数距离,即对于矩阵 A 和 B,其 p 范数距离为: #dist_{i,j} = ||A_i - B_j||_p #其中 i 和 j 分别表示矩阵 A 和 B 中的第 i 和 j 行,||.||_p 表示 p 范数。函数的返回值是一个矩阵,表示所有样本之间的距离。 # detach() 和 cpu() 方法是为了将计算结果从 GPU 转移到 CPU 上,并将其转换为 NumPy 数组。最终的结果将会是一个 NumPy 数组。 dm = torch.cdist(descs, descs, p=2).detach().cpu().numpy() # removing half mask = dm <= sim_th total = 0 matching_list = [] ar = np.arange(num_imgs) already_there_set = [] for st_idx in range(num_imgs-1): mask_idx = mask[st_idx] to_match = ar[mask_idx] if len(to_match) < min_pairs: to_match = np.argsort(dm[st_idx])[:min_pairs] for idx in to_match: if st_idx == idx: continue if dm[st_idx, idx] < 1000: matching_list.append(tuple(sorted((st_idx, idx.item())))) total+=1 matching_list = sorted(list(set(matching_list))) return matching_list
06-01
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值