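I have been reading PointConv recently and found that its architecture borrows from and modifies the PointNet++ design. I was lazy before and never read PointNet++ carefully, so now I have to go back and do it.
Overall structure of PointNet++: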
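Start with the encoder. It is built from SetAbstraction modules, and each SetAbstraction has three parts: a sampling layer, a grouping layer, and a PointNet layer. The figure draws two SetAbstraction modules, but in the code the author uses 3 to 4 of them depending on the task. That is understandable: in theory, each extra level abstracts the features over a larger neighborhood, so a deeper stack gives a finer-grained feature hierarchy.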
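To make the stacking concrete, here is a minimal shape trace of a three-level encoder. It is only a sketch: the helper set_abstraction_shapes and the npoint/mlp values are hypothetical, it assumes 3-D coordinates (C = 3), and it assumes the PointNet layer max-pools over the nsample axis after the shared MLP, so the output feature width is the last entry of mlp. The shapes follow the comments in the forward method shown further down.

```python
def set_abstraction_shapes(batch, d_in, npoint, mlp, group_all=False):
    """Shapes produced by one SetAbstraction call (hypothetical helper).

    d_in is D, the channel count of the input features (0 if points is None).
    Returns (in_channel, new_xyz_shape, new_points_shape).
    """
    c = 3                            # xyz coordinates
    in_channel = c + d_in            # grouped tensor carries C + D channels
    s = 1 if group_all else npoint   # group_all treats the whole cloud as one group
    return in_channel, (batch, c, s), (batch, mlp[-1], s)


# Hypothetical three-level encoder on a batch of 8 clouds without extra features.
levels = [
    dict(npoint=512, mlp=[64, 64, 128], group_all=False),
    dict(npoint=128, mlp=[128, 128, 256], group_all=False),
    dict(npoint=None, mlp=[256, 512, 1024], group_all=True),
]

d = 0
trace = []
for cfg in levels:
    in_channel, xyz_shape, feat_shape = set_abstraction_shapes(8, d, **cfg)
    trace.append((in_channel, xyz_shape, feat_shape))
    d = feat_shape[1]  # the next level consumes this level's output features
```

With these made-up settings the in_channel values come out as 3, 131 and 259, and the feature tensor shrinks from (8, 128, 512) to (8, 256, 128) to a single global vector of shape (8, 1024, 1). The point to notice is that each level's in_channel must equal the previous level's last mlp width plus 3.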
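The sampling layer uses FPS (farthest point sampling). There is plenty of material about it online, so I will not go into detail. Grouping comes next and needs little explanation either: it uses query_ball (ball query). For point clouds with non-uniform density the authors adopt additional techniques (multi-scale and multi-resolution grouping in the paper), and I recommend reading about those as well.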
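Before getting to the module itself, here is a minimal pure-Python sketch of the two steps just mentioned, FPS and ball query. It is not the repository's batched tensor code. The function names, the fixed start index, and the choice to pad short groups by repeating the first neighbor are assumptions made for illustration.

```python
import math


def farthest_point_sample(points, npoint, start=0):
    """Return the indices of npoint points picked by farthest point sampling.

    points: list of (x, y, z) tuples.
    start:  index of the first centroid (assumption: fixed here for
            repeatability; implementations usually pick it at random).
    """
    n = len(points)
    # min_dist[i] is the distance from point i to its nearest chosen centroid.
    min_dist = [math.inf] * n
    chosen = []
    current = start
    for _ in range(npoint):
        chosen.append(current)
        for i, p in enumerate(points):
            d = math.dist(p, points[current])
            if d < min_dist[i]:
                min_dist[i] = d
        # The next centroid is the point farthest from everything chosen so far.
        current = max(range(n), key=min_dist.__getitem__)
    return chosen


def query_ball(points, centroid, radius, nsample):
    """Return exactly nsample indices of points within radius of centroid.

    Groups with fewer than nsample hits are padded by repeating the first
    hit (an assumed padding rule). Returns [] if nothing is inside the ball.
    """
    hits = [i for i, p in enumerate(points) if math.dist(p, centroid) <= radius]
    if not hits:
        return []
    hits = hits[:nsample]
    return hits + [hits[0]] * (nsample - len(hits))


pts = [(0.0, 0.0, 0.0), (1.0, 0.0, 0.0), (0.1, 0.0, 0.0), (5.0, 0.0, 0.0)]
centroids = farthest_point_sample(pts, 2)             # [0, 3]
groups = [query_ball(pts, pts[c], 1.5, 3) for c in centroids]
# groups == [[0, 1, 2], [3, 3, 3]]
```

The second group shows why padding matters: the isolated point at x = 5 has only itself inside the ball, yet every group must have the same length so the groups can be stacked into one tensor for the PointNet layer.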
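Here is the code for the module. In it, points holds the per-point features, such as RGB or normal vectors; after grouping, these are concatenated with the xyz coordinates, which is where the C+D channels come from. xyz is the point cloud with coordinates only.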
import torch.nn as nn


class PointNetSetAbstraction(nn.Module):
    def __init__(self, npoint, radius, nsample, in_channel, mlp, group_all):
        super(PointNetSetAbstraction, self).__init__()
        self.npoint = npoint      # number of centroids picked by the sampling layer
        self.radius = radius      # ball query radius
        self.nsample = nsample    # number of points per group
        self.mlp_convs = nn.ModuleList()
        self.mlp_bns = nn.ModuleList()
        last_channel = in_channel
        # Shared MLP implemented as 1x1 convolutions, each followed by BatchNorm
        for out_channel in mlp:
            self.mlp_convs.append(nn.Conv2d(last_channel, out_channel, 1))
            self.mlp_bns.append(nn.BatchNorm2d(out_channel))
            last_channel = out_channel
        self.group_all = group_all

    def forward(self, xyz, points):
        """
        Input:
            xyz: input points position data, [B, C, N]
            points: input points data, [B, D, N]
        Return:
            new_xyz: sampled points position data, [B, C, S]
            new_points_concat: sample points feature data, [B, D', S]
        """
        # Move channels last: [B, C, N] -> [B, N, C]
        xyz = xyz.permute(0, 2, 1)
        if points is not None:
            points = points.permute(0, 2, 1)
        if self.group_all:
            new_xyz, new_points = sample_and_group_all(xyz, points)
        else:
            new_xyz, new_points = sample_and_group(self.npoint, self.radius, self.nsample, xyz, points)
        # new_xyz: sampled points position data, [B, npoint, C]
        # new_points: sampled points data, [B, npoint, nsample, C+D]
        new_points = new_points.permute(0, 3, 2, 1