pytorch 矩阵运算巧求点之间的距离

本文介绍了利用矩阵操作优化计算N个点之间距离的方法。通过将点集转换为张量并利用广播机制,将时间复杂度从O(N^2)降低到线性,展示了如何使用PyTorch进行点距计算,并给出了具体步骤和示例结果。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

给一些点的坐标(x1, y1), (x2, y2) … ,N个点,如果求它们中每个点到其他所有点的距离,
首先能想到的办法就是两层for循环,时间复杂度是O(N2)

而矩阵运算是经过优化的,一般来说比for循环要快。
现举例说明如何用矩阵运算来计算每个点到其他所有点的距离。

现有3个点

>>> c = torch.Tensor([[1,2], [2,3],[2,4]])
>>> c
tensor([[1., 2.],
        [2., 3.],
        [2., 4.]])
>>> c.shape 
torch.Size([3, 2])

这时分别在dim=0, dim=1上增加一个维度

>>> c1 = torch.unsqueeze(c, dim=1)
>>> c1.shape
torch.Size([3, 1, 2])
>>> c2 = c[None, ...]
>>> c2.shape
torch.Size([1, 3, 2])

用c1 - c2看会得到什么

>>> c1 - c2
tensor([[[ 0.,  0.],
         [-1., -1.],
         [-1., -2.]],

        [[ 1.,  1.],
         [ 0.,  0.],
         [ 0., -1.]],

        [[ 1.,  2.],
         [ 0.,  1.],
         [ 0.,  0.]]])

我们知道,python有broadcast功能,c1为[1,3,2], c2为[3,1,2], 相减时c1的dim=0维度会用同样的数据复制3次,变为[3,3,2],
c2的dim=1维度也会复制3次变为[3,3,2]。
那么相减就是shape为[3,3,2]。
可看到结果为每个点的(x, y) 减去其他所有点的(x, y)。

而求距离时希望差值的(x, y)求平方和,也就是x2+y2,
对最里面的维度求平方和,也就是[3,3,2]的最后一个维度2求平方和。

>>> torch.sum((c1 - c2)**2, dim=-1)
tensor([[0., 2., 5.],
        [2., 0., 1.],
        [5., 1., 0.]])

按列看,可看到每一列就是方才c1 - c2中每个差值(x, y)的x2+y2
最后开根号即可。

>>> torch.sum((c1 - c2)**2, dim=-1) ** 0.5
tensor([[0.0000, 1.4142, 2.2361],
        [1.4142, 0.0000, 1.0000],
        [2.2361, 1.0000, 0.0000]])
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

蓝羽飞鸟

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值