三维点云中无监督关键点的Separation loss损失函数代码手撕

最近在仔细阅读SC3K论文内的源码,发现他的分离损失的代码表达式特别精炼,通过运用大量的技巧来以很少的代码去最终实现目的,直接上手:

一、Separation Loss公式表达式

L s e p = 1 m a x ( 1 K ∑ i = 1 K ∣ ∣ k i − k N N ( k i , K P ) ∣ ∣ 2 , 0.01 ) L_{sep} = \frac{1}{max(\frac{1}{K}\sum_{i=1}{K}||k_i-kNN(k_i,KP)||_2, 0.01)} Lsep=max(K1i=1K∣∣kikNN(ki,KP)2,0.01)1

其中K表示关键点的个数, ki 表示第i个关键点的信息,这里关键点信息都是以 (x,y,z) 进行表征的, kNN 是计算第i个关键点ki最近的一个关键点的映射作用, KP=(k1,k2,...,kK) K个关键点的点集;

二、代码解读的基础知识

def separation_loss(kp):
   min_distances = torch.cat([torch.squeeze(
           torch.norm(kp[i].unsqueeze(1) - kp[i].unsqueeze(0), dim=2, p=None).topk(2, largest=False, dim=0)[
               0]) for i in range(len(kp))], dim=0)

   return 1/torch.mean(min_distances[min_distances>0])

img

仔细思考下这个代码以及上述公式,系考完后你会不会变成上面这位

那好,先逐步拆解下他的具体含义:

def separation_loss(kp):

这行代码就是def头,定义的方法名叫separation_loss,并且必须得传入kp这个参数,才能确保方法能正确调用,其中传入的参数kp是一个N*3维的关键点坐标,要是没接触的话你就想象成一系列三维空间中的点坐标矩阵,其中行的N表示有N个点,列为3表示(x,y,z)三个坐标数值。

min_distances = torch.cat([torch.squeeze(
           torch.norm(kp[i].unsqueeze(1) - kp[i].unsqueeze(0), dim=2, p=None).topk(2, largest=False, dim=0)[
               0]) for i in range(len(kp))], dim=0)

这次直接上重头戏,一开始这么分析是肯定不行的,我分析他的思路是先把死的技巧描述下,再以案例分析去理解每一个过程,最后再统一串起来:

这个代码用到torch.cat(xx, dim=0)、torch.squeeze(xx)、torch.norm(xx, dim=2, p=None)、torch.topk(2, largest=False,dim=0)

1.torch.cat

关于torch.cat(xx,dim=0)的话,以pytorch的官网案例:

>>> x = torch.randn(2, 3)
>>> x
tensor([[ 0.6580, -1.0969, -0.4614],
        [-0.1034, -0.5790,  0.1497]])
>>> torch.cat((x, x, x), 0)
tensor([[ 0.6580, -1.0969, -0.4614],
        [-0.1034, -0.5790,  0.1497],
        [ 0.6580, -1.0969, -0.4614],
        [-0.1034, -0.5790,  0.1497],
        [ 0.6580, -1.0969, -0.4614],
        [-0.1034, -0.5790,  0.1497]])
>>> torch.cat((x, x, x), 1)
tensor([[ 0.6580, -1.0969, -0.4614,  0.6580, -1.0969, -0.4614,  0.6580,
         -1.0969, -0.4614],
        [-0.1034, -0.5790,  0.1497, -0.1034, -0.5790,  0.1497, -0.1034,
         -0.5790,  0.1497]])

为了更清晰地知道他的轴变化情况,我们也以3维情形进行分析:

>>>x = torch.tensor([[[1,2,3]], [[4,5,6]], [[7,8,9]]])
>>>row, colum, z = x.shape
>>>print(row, colum, z) #返回3,1,3,即这个张量的维度
>>>torch.cat((x,x,x),dim=0)
tensor([[[1, 2, 3]],
        [[4, 5, 6]],
        [[7, 8, 9]],
        [[1, 2, 3]],
        [[4, 5, 6]],
        [[7, 8, 9]],
        [[1, 2, 3]],
        [[4, 5, 6]],
        [[7, 8, 9]]])

>>>torch.cat((x,x,x),dim=1)
tensor([[[1, 2, 3],
         [1, 2, 3],
         [1, 2, 3]],
        [[4, 5, 6],
         [4, 5, 6],
         [4, 5, 6]],
        [[7, 8, 9],
         [7, 8, 9],
         [7, 8, 9]]])

>>>torch.cat((x,x,x),dim=2)
tensor([[[1, 2, 3, 1, 2, 3, 1, 2, 3]],
        [[4, 5, 6, 4, 5, 6, 4, 5, 6]],
        [[7, 8, 9, 7, 8, 9, 7, 8, 9]]])

我对该轴的插入法则的理解的话:沿着某个轴进行插入,就先把该层级的[]这个括号给剔除,并且需要严格遵守数据x的一一对应关系进行拼接;

2.torch.squeeze(xx)、torch.unsqueeze(xx,dim)

该函数是用于处理变量xx的维度,只要xx在某些维度上是1,那么就会将是1的维度剔除,将其余均不为1的轴进行拼接,先看看他的基本用法,如下为pytorch官网的案例:

>>> x = torch.zeros(2, 1, 2, 1, 2)
>>> x.size()
torch.Size([2, 1, 2, 1, 2])
>>> y = torch.squeeze(x)
>>> y.size()
torch.Size([2, 2, 2])
>>> y = torch.squeeze(x, 0) #其中当第二个参数为0,就不会剔除维度上是1的对应轴(轴在这里也可以理解为维度)
>>> y.size()
torch.Size([2, 1, 2, 1, 2])
>>> y = torch.squeeze(x, 1)
>>> y.size()
torch.Size([2, 2, 1, 2])
>>> y = torch.squeeze(x, (1, 2, 3))
torch.Size([2, 2, 2])

写到这里我就又想去搞明白这个轴的变化反应在具体数值张量上的表达形式是什么样子,为了实现我这个好奇心,这里就一道介绍掉torch.unsqueeze(xx,dim):

>>> x = torch.tensor([1, 2, 3, 4])
>>> torch.unsqueeze(x, 0)
tensor([[ 1,  2,  3,  4]])
>>> torch.unsqueeze(x, 1)
tensor([[ 1],
        [ 2],
        [ 3],
        [ 4]])

该函数的作用是在指定的轴上新增一个维度,维度只能为1,注意多看看[]这个的变化!

>>>x = torch.tensor([[[1,2,3]], [[4,5,6]], [[7,8,9]]])
>>>row, colum, z = x.shape
tensor([[[1, 2, 3]],
        [[4, 5, 6]],
        [[7, 8, 9]]]) #维度为(3,1,3),这个1的话要这么数:第二个括号内只有一个同级的元素,所以就为1!!!第三个括号内有3个同级的元素,所以维度为3!!!

>>> y = x.unsqueeze(1)
tensor([[[[1, 2, 3]]],
        [[[4, 5, 6]]],
        [[[7, 8, 9]]]]) #注意这里维度变为了(3,1,1,3)

>>>z = torch.squeeze(y, 0)
>>>z.size() #输出为torch.Size([3, 1, 1, 3]),因为我们这指定了轴为0的地方,而y在轴为0的地方处维度是3,不为1,所以无效

>>>a = torch.squeeze(y, 1)
torch.Size([3, 1, 3])
tensor([[[1, 2, 3]],
        [[4, 5, 6]],
        [[7, 8, 9]]]) #可以发现他的轴与维度变化,剔除了原始轴为1时维度也为1的那个地方,并且直接进行了拼接

>>>b = torch.squeeze(y, 2)
torch.Size([3, 1, 3])
tensor([[[1, 2, 3]],
        [[4, 5, 6]],
        [[7, 8, 9]]]) #y的轴为2的地方维度也是1,所以也剔除掉了,虽然直接看他和a的变化没啥区别,但得知道这里是去掉了第三个[]才得到该结果的

>>>c = torch.squeeze(y, 3)
torch.Size([3, 1, 1, 3])
tensor([[[[1, 2, 3]]],
        [[[4, 5, 6]]],
        [[[7, 8, 9]]]]) #轴为3的地方维度不是1哦!!!所以不变

>>>d = torch.squeeze(y)
torch.Size([3, 3])
tensor([[1, 2, 3],
        [4, 5, 6],
        [7, 8, 9]]) #这里使用squeez的时候没有指定维度,直接将所有轴中维度为1的轴全部剔除掉了!!!

由上案例,我们可以直接得出:不用考虑哪个[],只要有几个维度是1,就照常无脑给他去掉几个[]就行

差不多就和上面一样,随着对应轴的变化,那么对应的[]也会新增或剔除;

3.torch.norm(xx, dim=2, p=None)

这个函数就是计算范数公式的,其中xx表示输入数据,我们这儿就是关键点数据,是xyz三维的坐标数据,其中dim=2表示在第三个轴上进行计算,p=None表示默认是欧几里得距离的计算公式

>>>xx = torch.tensor([[[1., 2., 3., 4.],
                       [5., 6., 7., 8.],
                       [9., 10., 11., 12.]],
                      [[13., 14., 15., 16.],
                       [17., 18., 19., 20.],
                       [21., 22., 23., 24.]]])
# 计算第 3 维(dim=2)的 L2 范数
>>>norm_result = torch.norm(xx, dim=2, p=None)
tensor([[ 5.4772, 13.9282, 22.6274],
        [30.4106, 38.9374, 47.5146]]) #其中norm_result[0][0]=sqrt(1**2+2**2+3**2+4**2)
norm_result[1][0]=sqrt(13**2+14**2+15**2+16**2)  以此类推,他这里是在第三个轴处进行计算

4.torch.topk(2, largest=False, dim=0)

torch.topk(input, k, dim=None, largest=True, sorted=True, *, out=None)
这是pytorch官网内提供的该方法的参数,其中input表示我们要输入的数据,k表示我们要提取前k个内容;

dim=None表示不指定特定维度,那么就会默认按照输入的最后一维进行运算比较来提取最小或最大值以及对应的索引,
我们这儿dim=0,表示指定第一个轴对应的维度进行比较,即行方向,按每一列进行比较;

largest = True表示返回规则为最大值,若largest = false则返回规则为返回最小值;

sorted=True表示返回的内容是按照升序的规则将k个内容进行排序返回内容,

>>>x = torch.tensor([[10, 20, 30],
                     [15, 25, 35],
                     [12, 22, 32],
                     [8, 18, 28]])
>>>values, indices = torch.topk(x, k=2, largest=False, dim=0)
tensor([[ 8, 18, 28],
        [10, 20, 30]])
tensor([[3, 3, 3],
        [0, 0, 0]])  #相同列,返回两个值,注意其中相同列返回的内容是并行的,是放入到两者相互对应的维度上吧!!!

到这儿基础知识应该介绍的差不多了吧~
在这里插入图片描述

三、代码手撕

min_distances = torch.cat([torch.squeeze(
        torch.norm(kp[i].unsqueeze(1) - kp[i].unsqueeze(0), dim=2, p=None).topk(2, largest=False, dim=0)[
            0]) for i in range(len(kp))], dim=0)

关于这里,我们就以这组数据来进行一步步的运算,将其串起来吧~

import torch

# 设置随机种子
torch.manual_seed(42)

x = torch.randint(low=0, high=10, size=(5, 3))
print(x)

>>>tensor([[2, 7, 6],
           [4, 6, 5],
           [0, 4, 0],
           [3, 8, 4],
           [0, 4, 1]]) #维度为5,3

生成完数据之后,我们先把torch.norm括号内的kp[i].unsqueeze(1) - kp[i].unsqueeze(0)描绘出来:

kp_1 = x.unsqueeze(1)
>>>tensor([[[2, 7, 6]],
           [[4, 6, 5]],
           [[0, 4, 0]],
           [[3, 8, 4]],
           [[0, 4, 1]]]) #维度为(5,1,3)这里在第二个轴新加一个维度为1的轴

kp_2 = x.unsqueeze(0)
>>>tensor([[[2, 7, 6],
            [4, 6, 5],
            [0, 4, 0],
            [3, 8, 4],
            [0, 4, 1]]]) #维度为(1,5,3)这里在第一个轴新加一个维度为1的轴

关于kp[i].unsqueeze(1) - kp[i].unsqueeze(0)的话,我们其实先理解为kp_1[i]-kp_2[i]的操作即可,因为后面有个循环,那么本质就是类似kp_1 - kp_2这个算法,这时可能有小白觉着这里维度都对应不上,怎么能进行减法操作呢?其实这里用到了三维上的广播机制进行运算,并且是从第三维度直接进行欧几里得距离运算。

这里先展示一下这个三维的广播机制的形式:

n = kp_1 - kp_2
>>>tensor([[[ 0,  0,  0],
            [-2,  1,  1],
            [ 2,  3,  6],
            [-1, -1,  2],
            [ 2,  3,  5]],

        [[ 2, -1, -1],
            [ 0,  0,  0],
            [ 4,  2,  5],
            [ 1, -2,  1],
            [ 4,  2,  4]],

        [[-2, -3, -6],
            [-4, -2, -5],
            [ 0,  0,  0],
            [-3, -4, -4],
            [ 0,  0, -1]],

        [[ 1,  1, -2],
            [-1,  2, -1],
            [ 3,  4,  4],
            [ 0,  0,  0],
            [ 3,  4,  3]],

        [[-2, -3, -5],
            [-4, -2, -4],
            [ 0,  0,  1],
            [-3, -4, -3],
            [ 0,  0,  0]]])

这是计算结果;我们这就以n[0,:,:]的计算结果块来展示kp_1,kp_2各自在广播机制运算下运算这个结果块的展开规律,其他块也就类似推理:
n[0,:,:]就是[[ 0,  0,  0],
            [-2,  1,  1],
            [ 2,  3,  6],
            [-1, -1,  2],
            [ 2,  3,  5]],这块内容

kp_1[0,:,:] = [[2, 7, 6],              kp_2[0,:,:] = [[2, 7, 6],
               [2, 7, 6],                             [4, 6, 5],
               [2, 7, 6],                             [0, 4, 0],
               [2, 7, 6],                             [3, 8, 4],
               [2, 7, 6]]                             [0, 4, 1]]
这两块差不多就这个意思,如果不是[0,:,:]这个表达的话,也如下这个意思,只是后续的补充内容就懒得敲了,因为是手推的呜呜呜
kp_1[,:,:] = [[[2, 7, 6],              kp_2[0,:,:] = [[[2, 7, 6],
                 [2, 7, 6],                              [4, 6, 5],
                 [2, 7, 6],                              [0, 4, 0],
                 [2, 7, 6],                              [3, 8, 4],
                 [2, 7, 6]],                             [0, 4, 1]],
                [[a, xi, ba],                           [[xi, ba, pro],
                 ...         ]]]                           ...        ]]]
如果不清楚这个广播机制的话,这个地方可以好好和原始的kp_1和kp_2好好琢磨一下

那么很好,上面的内容都基本上清楚了,我们就来计算下torch.norm(kp_1-kp_2, dim=2, p=None)的结果,注意我们这里是沿着第三个轴!!!(dim=2)

n = n.float() #因为我上面的数据类型都是int,即long类型,代码会报错,linalg.norm只接受float或复数类型
u = torch.linalg.norm(n, dim=2, ord=2)  # `ord=2` 指二范数
#哈哈笑死,我现在pytorch版本过高,官网把torch.norm分解掉了,现在用精度更高的linalg.norm进行替换了~没事,这个功能实现的是一样的

>>>tensor([[0.0000, 2.4495, 7.0000, 2.4495, 6.1644],
           [2.4495, 0.0000, 6.7082, 2.4495, 6.0000],
           [7.0000, 6.7082, 0.0000, 6.4031, 1.0000],
           [2.4495, 2.4495, 6.4031, 0.0000, 5.8310],
           [6.1644, 6.0000, 1.0000, 5.8310, 0.0000]])

现在阔以好好看看这个结果咯!!!
以n[0,:,:]为例吧,他的计算结果就是u[0,:]
n[0,:,:]= [[[ 0,  0,  0],
            [-2,  1,  1],
            [ 2,  3,  6],
            [-1, -1,  2],
            [ 2,  3,  5]]]
那么sqrt(0**2 + 0**2 + 0**2 = 0      sqrt((-2)**2 + 1**2 + 1**2) = 2.4495  sqrt(2**2 + 3**2 + 6**2) = 7
sqrt((-1)**2 + (-1)**2 + 2**2) = 2.4495    sqrt(2**2 + 3**2 + 5**2) = 6.1644

那么u这个5*5的张量而言,u[i][j]就表示第i个点与第j个点之间的欧氏空间距离!!!
好好想想这块,想不通就手写一遍,就能理解了!

在这里插入图片描述

兄弟men别急,别忘了后面还紧跟着个topk哦~~~

u = tensor([[0.0000, 2.4495, 7.0000, 2.4495, 6.1644],
            [2.4495, 0.0000, 6.7082, 2.4495, 6.0000],
            [7.0000, 6.7082, 0.0000, 6.4031, 1.0000],
            [2.4495, 2.4495, 6.4031, 0.0000, 5.8310],
            [6.1644, 6.0000, 1.0000, 5.8310, 0.0000]])

那么这个topk的运算本质上就是在这组张量进行操作了~
原始代码的执行本质上就等价于:  u.topk(2, largest=False, dim=0)[0]  的运算!
为啥后面会跟着个[0]呢?因为topk返回的是两组对象,还记得第二大节那个基础函数命令解析吗?所以我们只提取第一组,
也就是距离这个数据,而不需要那个位置的索引!!!

再来分析,dim=0,我们这里是按行为方向,从列进行比较!!! 又由于largest=False,表示返回的是最小距离,不是最大值!!!

u.topk(2, largest=False, dim=0)[0]
>>>tensor([[0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
           [2.4495, 2.4495, 1.0000, 2.4495, 1.0000]]) #这里全为0的信息从上面的u数值可以直接得知:他是本身与本身的距离;
#这就是为啥作者只选前2个了,也就是u.topk(2, largest=False, dim=0)[0]这里的2
#这里你得头脑清醒,我们还有一个最重要的内容:kNN是要计算距离ki最近的别的关键点坐标,然后来计算欧氏距离
#哈哈哈我上面是一个思路,可作者在这里实现的思路是先将所有点之间的距离都给计算一遍,然后直接取距离最小的那个关键点对应的距离
#并且作者是以一个对称矩阵进行的实现,真的是,妙啊啊啊啊啊啊,不愧是多所高校以及世界顶流大企业联合发表的呜呜呜

讲到这,我们就得回顾原始代码了,注意我们这边是一个列表推导,而不是完整的kp_1-kp_2,所以最后的最后,作者是通过如下列表推导式,并借助squeeze进行提取到当前k_i这个关键点与最近关键点间的最小距离的!!!!!!

torch.squeeze(torch.norm(kp_1[i]-kp_2[i], dim=2, p=None).top(2,largest=False,dim=0)[0] for i in range(len(kp))
关于这行代码,我们就以i=0进行可视化过程,i=1...N的话,基本过程和i=0是一致的;

i=0时
kp_1[0] = [[[2, 7, 6]]]             kp_2[0] = [[[2, 7, 6],
                                                [4, 6, 5],
                                                [0, 4, 0],
                                                [3, 8, 4],
                                                [0, 4, 1]]]    
 #要是搞不清楚这咋来的,可以从头仔细看起,整篇逻辑应该还算清晰

kp_1[0] - kp_2[0] = [[[2, 7, 6],        -     [[[2, 7, 6],           =       [[[0, 0, 0],
                      [2, 7, 6],                [4, 6, 5],                     [-2, 1, 1],
                      [2, 7, 6],                [0, 4, 0],                     [2, 3, 6],
                      [2, 7, 6],                [3, 8, 4],                     [-1, -1, 2],
                      [2, 7, 6]]]               [0, 4, 1]]] #广播机制          [2, 3, 5]]]

那么torch.norm(kp_1[0] - kp_2[0], dim=2, p=None)的计算结果就和上面我计算过的一样:
n[0,:,:]= [[[ 0,  0,  0],
            [-2,  1,  1],
            [ 2,  3,  6],
            [-1, -1,  2],
            [ 2,  3,  5]]]
那么sqrt(0**2 + 0**2 + 0**2 = 0      sqrt((-2)**2 + 1**2 + 1**2) = 2.4495  sqrt(2**2 + 3**2 + 6**2) = 7
sqrt((-1)**2 + (-1)**2 + 2**2) = 2.4495    sqrt(2**2 + 3**2 + 5**2) = 6.1644

所以torch.norm(kp_1[0] - kp_2[0], dim=2, p=None) = [[0, 2.4495, 7, 2.4495, 6.1644]] #因为是沿着dim=2轴进行计算的哦!!!
要理清这里的现实含义:这个代表第1个点与其本身以及其他4个关键点坐标的欧几里得距离;

随后通过[[0, 2.4495, 7, 2.4495, 6.1644]].top(2,largest=False,dim=0)[0]进行计算:
return [[0, 2.4495]] #我们这里完全不关心最近关键点是哪个,只关心他们之间的欧几里得距离,所以就算有两个关键点的距离一样,也不影响

那么之后通过 torch.squeeze([[0, 2.4495]])进行维度为1的轴剔除操作
return [0, 2.4495] 

那么最后!就是torch.cat(xx, dim=0)这一个工作了!!! #这个xx在这个局部含义内就表示为如下u返回的两个最小的距离列表
u = tensor([[0.0000, 2.4495, 7.0000, 2.4495, 6.1644],
            [2.4495, 0.0000, 6.7082, 2.4495, 6.0000],
            [7.0000, 6.7082, 0.0000, 6.4031, 1.0000],
            [2.4495, 2.4495, 6.4031, 0.0000, 5.8310],
            [6.1644, 6.0000, 1.0000, 5.8310, 0.0000]])
依据这个距离计算矩阵,我们只要将每行中返回的对应前两个最小距离进行cat操作即可:
u[0]' return [0, 2.4495], u[1]' return [0, 2.4495]    #均结果torch.squeeze操作处理
u[2]' return  [0, 1.0000], u[3]' return [0, 2.4495]  u[4]' return [0, 1.0000]

那么通过torch.cat((u[0]',u[1]', u[2]', u[3]', u[4]'), dim=0)以行按列进行cat:
>>> [[0, 2.4495],
     [0, 2.4495],
     [0, 1.0000],
     [0, 2.4495],
     [0, 1.0000]]  = min_distances

最后就是结果返回的运算程序解析咯~

return 1/torch.mean(min_distances[min_distances>0])

其中min_distances[min_distances>0])的话,[min_distances>0]是借助逻辑运算符来得出大于0的位置进而提取出对应内容
min_distances > 0.0
>>>tensor([[False,  True],
           [False,  True],
           [False,  True],
           [False,  True],
           [False,  True]])

min_distances[min_distances > 0.0]
>>>tensor([2.4495, 2.4495, 1.0000, 2.4495, 1.0000])

最后通过torch.mean(min_distances[min_distances > 0.0])计算,即可最终解决上述运算公式!!!
不过还有个小疑问就是,这行代码并没有执行他对应的max(xx, 0.01)这个函数,可能是作者知道他们关键点之间的距离的最小距离
不会低于0.01~

在这里插入图片描述

四、总结

通过这简单的主要4行代码,我码了近1.4w字,来从头手撕他内部运行时所需掌握的每一步内容,其中尤为涉及到张量轴的操作,这在我们处理图像数据中也是尤为重要的一点,而且论文作者在写这个简洁代码,我看得出他们对该公式的思考深度是独特的,肯定衡量了代码简洁性、运算的不同表达方式、简单又快捷的逻辑运算同时还达到计算要求的目的,太吓人了!!!

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值