给一些点的坐标(x1, y1), (x2, y2) … ,N个点,如果求它们中每个点到其他所有点的距离,
首先能想到的办法就是两层for循环,时间复杂度是O(N2)
而矩阵运算是经过优化的,一般来说比for循环要快。
现举例说明如何用矩阵运算来计算每个点到其他所有点的距离。
现有3个点
>>> c = torch.Tensor([[1,2], [2,3],[2,4]])
>>> c
tensor([[1., 2.],
[2., 3.],
[2., 4.]])
>>> c.shape
torch.Size([3, 2])
这时分别在dim=0, dim=1上增加一个维度
>>> c1 = torch.unsqueeze(c, dim=1)
>>> c1.shape
torch.Size([3, 1, 2])
>>> c2 = c[None, ...]
>>> c2.shape
torch.Size([1, 3, 2])
用c1 - c2看会得到什么
>>> c1 - c2
tensor([[[ 0., 0.],
[-1., -1.],
[-1., -2.]],
[[ 1., 1.],
[ 0., 0.],
[ 0., -1.]],
[[ 1., 2.],
[ 0., 1.],
[ 0., 0.]]])
我们知道,python有broadcast功能,c1为[1,3,2], c2为[3,1,2], 相减时c1的dim=0维度会用同样的数据复制3次,变为[3,3,2],
c2的dim=1维度也会复制3次变为[3,3,2]。
那么相减就是shape为[3,3,2]。
可看到结果为每个点的(x, y) 减去其他所有点的(x, y)。
而求距离时希望差值的(x, y)求平方和,也就是x2+y2,
对最里面的维度求平方和,也就是[3,3,2]的最后一个维度2求平方和。
>>> torch.sum((c1 - c2)**2, dim=-1)
tensor([[0., 2., 5.],
[2., 0., 1.],
[5., 1., 0.]])
按列看,可看到每一列就是方才c1 - c2中每个差值(x, y)的x2+y2
最后开根号即可。
>>> torch.sum((c1 - c2)**2, dim=-1) ** 0.5
tensor([[0.0000, 1.4142, 2.2361],
[1.4142, 0.0000, 1.0000],
[2.2361, 1.0000, 0.0000]])