PointNet++在处理分割任务的时候需要将下采样的点还原到与输入相同的点数,便于做每个点的预测。但是在论文中只给了一个简单的描述和公式,不是很好理解,因此在这里记录一下我的理解过程。
1.FP模块的目的
PointNet++会随着网络逐层降低采样的点数,这样来保证网络获得足够的全局信息,但是这样会导致无法完成分割任务,因为分割任务是一个端到端的,必须保证输出与输入点数相同。
一种完成分割任务的方法就是不下采样点,始终将所有点放入网络进行计算。但这样需要消耗大量计算成本。另一种比较常用的方法就是进行插值了,利用已知点,来补足需要点。
FP模块的目的就是利用已知的特征点来进行插值,使网络输出与输入点数相同的特征。具体的做法可见下一步。
2.如何插值
对于如何插值,在2d图像中是根据该像素点相邻的点来确定插值点的像素值。这种插值方式相当于找了最近距离的像素点来确定被插点的像素值。对应在点云中,也可以根据距离来对特征进行插值。
具体地,可以见作者的描述。对于NL层的深层特征,假设它有128个点,每个点有d+C个特征;它的前一层,也就NL-1层,假设有512个点。其中NL-1的点数是大于NL的点数,因为我们之前会进行下采样,点数会被压缩。
现在,我们已知NL-1层和NL层的特征和坐标,目的就是将128个点,上采样到512个点。如何实现呢?答案是利用距离的来插值(特别要注意区分特征和坐标在其中的意义)。具体步骤如下:
1)我们利用NL-1层和NL层的坐标计算任意两个点之间的距离,那么我们会得到一个距离矩阵,
它的尺寸为512x128。它的含义就是NL-1(低维)中的每个点与NL(高维)中的每个点的距离。
2)将距离矩阵进行排序,找到NL层中与NL-1层中距离最近的三个点,并记录它们值和索引,标记为dist和idx,注意,此时的索引和距离矩阵的尺寸是512x3,也就是NL-1中与NL距离最近的三个点的索引。
3)将idx拿着去NL层的特征中进行查询,得到512x3x(c+d)的特征矩阵。这里比较不好理解,为什么这一步能够将128上采样到512,其实这里会进行重复采样,因为前面得到的idx矩阵是512x3,这个里面的每一行的三个元素都是在128个点中的索引。例如,第一行的值可能是 2 3 4 第二行的值可能是3 4 5 …如此进行重复采样。
4)前面我们已经将特征进行上采样,但他们本质还是原来128个点对应的特征,如果不进行变换,那么这个上采样将毫无意义。所以,我们需要按前面距离矩阵进行插值,来改变特征的值。前面我们在dist矩阵里面保留的就是我们距离值,因此,我们只需要将它与特征进行扩维相乘就可以了(因为特征维数要大于距离维数),也就等同于加权了。
以上就是整个FP模块在做的事情,如果没有看懂,接下来通过debug代码再次分析一下整个过程:
class PointNetFeaturePropagation(nn.Module):
# in_channel=1280, mlp=[256, 256]
def __init__(self, in_channel, mlp):
super(PointNetFeaturePropagation, self).__init__()
self.mlp_convs = nn.ModuleList()
self.mlp_bns = nn.ModuleList()
last_channel = in_channel
for out_channel in mlp:
self.mlp_convs.append(nn.Conv1d(last_channel, out_channel, 1))
self.mlp_bns.append(nn.BatchNorm1d(out_channel))
last_channel = out_channel
def forward(self, xyz1, xyz2, points1, points2):
# 前面两层的质心和前面两层的输出
"""
Input:
利用前一层的点对后面的点进行插值
xyz1: input points position data, [B, C, N] l2层输出 xyz
xyz2: sampled input points position data, [B, C, S] l3层输出 xyz
points1: input points data, [B, D, N] l2层输出 points
points2: input points data, [B, D, S] l3层输出 points
Return:
new_points: upsampled points data, [B, D', N]
"""
" 将B C N 转换为B N C 然后利用插值将高维点云数目S 插值到低维点云数目N (N大于S)"
" xyz1 低维点云 数量为N xyz2 高维点云 数量为S"
xyz1 = xyz1.permute(0, 2, 1) # 第一次插值时 2,3,128 ---> 2,128,3 | 第二次插值时 2,3,512--->2,512,3
xyz2 = xyz2.permute(0, 2, 1) # 第一次插值时2,3,1 ---> 2 ,1,3 | 第二次插值时 2,3,128--->2,128,3
points2 = points2.permute(0, 2, 1)# 第一次插值时2,1021,1 --->2,1,1024 最后低维信息,压缩成一个点了 这个点有1024个特征
# 第二次插值 2,256,128 --->2,128,256
B, N, C = xyz1.shape # N = 128 低维特征的点云数 (其数量大于高维特征)
_, S, _ = xyz2.shape # s = 1 高维特征的点云数
if S == 1:
"如果最后只有一个点,就将S直复制N份后与与低维信息进行拼接"
interpolated_points = points2.repeat(1, N, 1) # 2,128,1024 第一次直接用拼接代替插值
else:
"如果不是一个点 则插值放大 128个点---->512个点"
"此时计算出的距离是一个矩阵 512x128 也就是512个低维点与128个高维点 两两之间的距离"
dists = square_distance(xyz1, xyz2) # 第二次插值 先计算高维与低维的距离 2,512,128
dists, idx = dists.sort(dim=-1) # 2,512,128 在最后一个维度进行排序 默认进行升序排序,也就是越靠前的位置说明 xyz1离xyz2距离较近
"找到距离最近的三个邻居,这里的idx:2,512,3的含义就是512个点与128个距离最近的前三个点的索引," \
"例如第一行就是:对应128个点中那三个与512中第一个点距离最近"
dists, idx = dists[:, :, :3], idx[:, :, :3] # [B, N, 3] 2,512,3 此时dist里面存放的就是 xyz1离xyz2最近的3个点的距离
dist_recip = 1.0 / (dists + 1e-8) # 求距离的倒数 2,512,3 对应论文中的 Wi(x)
"对dist_recip的倒数求和 torch.sum keepdim=True 保留求和后的维度 2,512,1"
norm = torch.sum(dist_recip, dim=2, keepdim=True) # 也就是将距离最近的三个邻居的加起来 此时对应论文中公式的分母部分
weight = dist_recip / norm # 2,512,3
"""
这里的weight是计算权重 dist_recip中存放的是三个邻居的距离 norm中存放是距离的和
两者相除就是每个距离占总和的比重 也就是weight
"""
t = index_points(points2, idx) # 2,512,3,256
interpolated_points = torch.sum(index_points(points2, idx) * weight.view(B, N, 3, 1), dim=2)
"""
points2: 2,128,256 (128个点 256个特征) idx 2,512,3 (512个点中与128个点距离最近的三个点的索引)
index_points(points2, idx) 从高维特征(128个点)中找到对应低维特征(512个点) 对应距离最小的三个点的特征 2,512,3,256
这个索引的含义比较重要,可以再看一下idx参数的解释,其实2,512,3,256中的512都是高维特征128个点组成的。
例如 512中的第一个点 可能是由128中的第 1 2 3 组成的;第二个点可能是由2 3 4 三个点组成的
-------------------------------------------
weight: 2,512,3 weight.view(B, N, 3, 1) ---> 2,512,3,1
a与b做*乘法,原则是如果a与b的size不同,则以某种方式将a或b进行复制,使得复制后的a和b的size相同,然后再将a和b做element-wise的乘法。
这样做乘法就相当于 512,3,256 中的三个点的256维向量都会去乘它们的距离权重,也就是一个数去乘256维向量
torch.sum dim=2 最后在第二个维度求和 取三个点的特征乘以权重后的和 也就完成了对特征点的上采样
"""
if points1 is not None:
points1 = points1.permute(0, 2, 1) # 2,256,128 -->2,128,256
new_points = torch.cat([points1, interpolated_points], dim=-1) # 2,128,1280
else:
new_points = interpolated_points
new_points = new_points.permute(0, 2, 1)
for i, conv in enumerate(self.mlp_convs):
bn = self.mlp_bns[i]
new_points = F.relu(bn(conv(new_points)))
return new_points
我们只需着重看它的上采样部分,对于它的卷积部分这里就不作详细分析了。
大致流程基本上和作者描述的公式相同,这里着重讲解一些代码中这个weight变量,之前我觉得它一直和论文的公式对不上,后来我发现要论文的公式拆开看就能对上了。
这里的weight是计算权重 dist_recip中存放的是三个邻居的距离 norm中存放是距离的和
两者相除就是每个距离占总和的比重 也就是weight