【代码阅读】PointNet++中的Three_nn的CUDA实现

文章目录

本文为PointNet++ CUDA代码阅读系列的第四部分,其他详见:
(一)PointNet++代码梳理
(二)PointNet++中的FPS的CUDA实现
(三)PointNet++中ball query的CUDA实现
(四)PointNet++中的Three_nn的CUDA实现


给定点集known和unknown,Three_nn实现的功能是对于unknown的每个点,找到其在known中最临近的3个点的距离和下标,直接看cu代码,在src/interpolate_gpu.cu中:

__global__ void three_nn_kernel_fast(int b, int n, int m, const float *__restrict__ unknown, 
    const float *__restrict__ known, float *__restrict__ dist2, int *__restrict__ idx) {
   
   
    // unknown: (B, N, 3)
    // known: (B, M, 3)
    // output: 
    //      dist2: (B, N, 3)
    //      idx: (B, N, 3)
    
    int bs_idx = blockIdx.y;  // 找到这个线程处理的batch
    int pt_idx = blockIdx.x * blockDim
### PointNet++在PyTorch中的实现代码 PointNet++ 是一种基于点云数据的深度学习架构,用于三维分类和分割任务。其 PyTorch 实现可以在多个开源项目中找到,以下是一个具体的实现示例[^1]。 ```python import torch import torch.nn as nn import torch.nn.functional as F class PointNetSetAbstraction(nn.Module): def __init__(self, npoint, radius, nsample, in_channel, mlp, group_all): super(PointNetSetAbstraction, self).__init__() self.npoint = npoint self.radius = radius self.nsample = nsample self.mlp_convs = nn.ModuleList() self.mlp_bns = nn.ModuleList() last_channel = in_channel for out_channel in mlp: self.mlp_convs.append(nn.Conv2d(last_channel, out_channel, 1)) self.mlp_bns.append(nn.BatchNorm2d(out_channel)) last_channel = out_channel self.group_all = group_all def forward(self, xyz, points): """ Input: xyz: input points position data, [B, C, N] points: input points data, [B, D, N] Return: new_xyz: sampled points position data, [B, C, S] new_points_concat: sample points feature data, [B, D', S] """ # 采样与分组逻辑(省略具体实现) new_xyz = index_points(xyz, farthest_point_sample(xyz, self.npoint)) grouped_xyz, grouped_points = query_and_group(self.npoint, self.radius, self.nsample, xyz, new_xyz, points) # MLP 处理 for i, conv in enumerate(self.mlp_convs): bn = self.mlp_bns[i] grouped_points = F.relu(bn(conv(grouped_points))) new_points = torch.max(grouped_points, 2)[0] return new_xyz, new_points class PointNetFeaturePropagation(nn.Module): def __init__(self, in_channel, mlp): super(PointNetFeaturePropagation, self).__init__() self.mlp_convs = nn.ModuleList() self.mlp_bns = nn.ModuleList() last_channel = in_channel for out_channel in mlp: self.mlp_convs.append(nn.Conv1d(last_channel, out_channel, 1)) self.mlp_bns.append(nn.BatchNorm1d(out_channel)) last_channel = out_channel def forward(self, xyz1, xyz2, points1, points2): """ Input: xyz1: input points position data, [B, C, N] xyz2: sampled input points position data, [B, C, S] points1: input points data, [B, D, N] points2: input points data, [B, D, S] Return: new_points: upsampled points data, [B, D', N] """ dist, idx = three_nn(xyz1.permute(0, 2, 1), xyz2.permute(0, 2, 1)) dist_recip = 1.0 / (dist + 1e-8) norm = torch.sum(dist_recip, dim=2, keepdim=True) weight = dist_recip / norm interpolated_points = three_interpolate(points2, idx, weight) if points1 is not None: new_points = torch.cat([points1, interpolated_points], dim=1) else: new_points = interpolated_points for i, conv in enumerate(self.mlp_convs): bn = self.mlp_bns[i] new_points = F.relu(bn(conv(new_points))) return new_points ``` 上述代码片段展示了 PointNet++ 的两个核心模块:`PointNetSetAbstraction` 和 `PointNetFeaturePropagation`。这些模块分别实现了点云的分层特征提取和特征传播功能[^2]。 此外,完整的 PointNet++ 实现代码可以在以下 GitHub 仓库中找到: - **PointNet2.PyTorch**: [https://gitcode.com/gh_mirrors/po/Pointnet2.PyTorch](https://gitcode.com/gh_mirrors/po/Pointnet2.PyTorch) [^1] - **Pointnet_Pointnet2_pytorch**: [https://gitcode.com/gh_mirrors/po/Pointnet_Pointnet2_pytorch](https://gitcode.com/gh_mirrors/po/Pointnet_Pointnet2_pytorch) [^2] 这些仓库提供了完整的训练脚本、模型定义以及预训练权重,适合开发者快速上手并进行研究或应用开发。 ###
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值