文章目录
本文为PointNet++ CUDA代码阅读系列的第四部分,其他详见:
(一)PointNet++代码梳理
(二)PointNet++中的FPS的CUDA实现
(三)PointNet++中ball query的CUDA实现
(四)PointNet++中的Three_nn的CUDA实现
给定点集known和unknown,Three_nn实现的功能是对于unknown的每个点,找到其在known中最临近的3个点的距离和下标,直接看cu代码,在src/interpolate_gpu.cu中:
__global__ void three_nn_kernel_fast(int b, int n, int m, const float *__restrict__ unknown,
const float *__restrict__ known, float *__restrict__ dist2, int *__restrict__ idx) {
// unknown: (B, N, 3)
// known: (B, M, 3)
// output:
// dist2: (B, N, 3)
// idx: (B, N, 3)
int bs_idx = blockIdx.y; // 找到这个线程处理的batch
int pt_idx = blockIdx.x * blockDim