PyTorch 版的完整可执行代码,包含:
- PointNet(原始网络)
- PointNet++(采用分层局部聚合:Farthest Point Sampling → Ball Query → PointNet 层)
- 数据读取、训练与验证
- 所有关键步骤均做了详细注释,并使用
torch.nn.functional的高效实现
代码已在 PyTorch 1.13 + CUDA 11.x 环境下通过单机实验。
如需在 CPU 上跑,只需要把device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')改成'cpu'。
# pointnet_pointnetpp.py
import os
import math
import random
from pathlib import Path
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
#############################
# 1. 数据集(ModelNet40 示例)
#########
订阅专栏 解锁全文
103

被折叠的 条评论
为什么被折叠?



