Pointnet++代码架构理解

PointNet++(或简称 PointNet2)是 PointNet 的改进版本,专门用于处理更复杂的点云数据,它在 PointNet 的基础上引入了局部特征学习和层次化结构,使得网络能够捕捉点云数据中的局部结构,并且在处理大规模点云时更加高效。

下面是对 PointNet++ GitHub 源代码架构的详细讲解。PointNet++的实现是基于 PyTorch 和 TensorFlow 等深度学习框架,以下主要以 GitHub 上的 PointNet++ PyTorch 实现为例(代码路径和文件名称可能会有所不同,但大体结构是一致的)。

源码链接:

GitHub - erikwijmans/Pointnet2_PyTorch: PyTorch implementation of Pointnet2/Pointnet++

1. PointNet++ 源代码文件夹结构

典型的 PointNet++ GitHub 仓库的文件夹结构如下:

PointNet2/
├── models/
│   ├── pointnet2_classification.py
│   ├── pointnet2_segmentation.py
│   ├── pointnet2_utils.py
├── data/
│   ├── ModelNet40/
│   ├── ScanNet/
│   ├── prepare_data.py
│   ├── data_utils.py
├── utils/
│   ├── pytorch_utils.py
│   ├── knn.py
│   ├── log.py
│   ├── metrics.py
├── train.py
├── test.py
├── config.py
└── requirements.txt

2. 重要文件和功能讲解

2.1 models/

这个文件夹包含了 PointNet++ 网络的主要实现代码,通常会有以下几种模块:

  • pointnet2_classification.py

    • 这个文件包含了用于分类任务的 PointNet++ 模型的定义。
    • PointNet++ 中,网络的核心部分是通过多层次的 Set Abstraction (SA) 来抽象局部特征,之后通过全局特征池化来获得最终的分类结果。
  • pointnet2_segmentation.py

    • 用于点云分割任务的实现。与分类任务不同,分割任务要求每个点都有一个标签,因此模型的输出通常是每个点的类别(而非整个点云的类别)。
  • pointnet2_utils.py

    • 存放了一些用于 PointNet++ 网络的实用工具函数,例如,点云采样、分组、邻域查询等。
2.2 data/

这个文件夹包含了与数据集相关的代码,通常包括以下几个文件:

  • ModelNet40/ScanNet/

    • 存放数据集的文件夹,分别包含了用于点云分类的 ModelNet40 数据集和用于语义分割的 ScanNet 数据集。
    • 这些文件夹下通常会有一些预处理后的点云数据(例如,转换为 .ply.txt 格式)。
  • prepare_data.py

    • 这个文件用于预处理数据集。比如将点云数据转换为适合模型输入的格式,或者进行数据增强操作。
  • data_utils.py

    • 提供了一些数据加载和处理的工具函数,如数据批次生成、点云数据的归一化、随机旋转等。
2.3 utils/

这个文件夹包含了一些辅助性工具,例如:

  • pytorch_utils.py

    • 包含一些 PyTorch 的实用工具函数,如权重初始化、模型保存和加载等。
  • knn.py

    • 这个文件实现了 k-近邻(KNN)算法。在 PointNet++ 中,KNN 是用来查询点云中各个点的邻居,从而进行局部特征提取。
  • log.py

    • 用于记录训练过程中的日志,输出训练进度、损失函数等信息。
  • metrics.py

    • 提供一些评估指标,如分类精度、IoU(Intersection over Union)等。
2.4 train.pytest.py
  • train.py

    • 这是模型的训练入口,包含了训练过程中的主要逻辑。通常,它会定义训练数据加载器、网络模型、损失函数、优化器等。它还会指定一些超参数,比如学习率、批次大小、训练轮数等。
  • test.py

    • 用于评估训练好的模型,通常包含对模型的测试、预测结果的保存和评估等功能。
2.5 config.py
  • 这个文件用于配置超参数和模型参数。它通常会列出模型结构的各个参数(例如层数、每层的点数、使用的网络模块等),以及训练过程中的参数(学习率、批次大小等)。
2.6 requirements.txt
  • 这个文件列出了运行 PointNet++ 所需的 Python 库和依赖。通过 pip install -r requirements.txt 可以安装所有的依赖库。

3. 核心模块解读

3.1 Set Abstraction (SA) 层

Set Abstraction(SA)是 PointNet++ 的关键模块,负责对输入点云进行局部特征抽象。它通常包括以下几个步骤:

  1. Farthest Point Sampling (FPS)

    • 通过 FPS 方法,逐步选择代表性的点,从而减少点云的规模,且这些点的分布尽可能均匀。
  2. KNN 或 Ball Query

    • 使用 KNN(k-近邻)或 Ball Query(球查询)方法找到每个点的局部邻域。
  3. MLP 和 Pooling 操作

    • 对邻域内的点使用 MLP 网络进行特征学习,并使用池化操作(如最大池化)来汇聚局部特征。
  4. 全局特征融合

    • 在每一层的 SA 后,进行全局特征池化(Global Pooling),使得模型可以捕捉到整个点云的全局信息。
3.2 逐层处理和特征聚合

PointNet++ 的网络结构通常是一个由多个 SA 层组成的堆叠结构。每一层 SA 层负责从上层抽取的特征中进一步提取更高阶的局部特征,直到网络达到足够的深度。

3.3 分类和分割头
  • 分类头:通常是一个简单的 MLP 层,通过最大池化将全局特征聚合成一个固定长度的向量,然后通过一个全连接层输出最终的类别预测。

  • 分割头:与分类头类似,不同之处在于每个点都有一个单独的预测标签。因此,输出层通常是一个点数大小的向量,每个元素代表一个点的类别。

4. 训练和评估过程

  • 训练

    • train.py 中,数据加载器会从磁盘加载数据,并进行数据增强(如随机旋转、平移等),以增强模型的鲁棒性。
    • 模型会使用优化器(如 Adam)和损失函数(如交叉熵损失或欧几里得损失)进行训练。
  • 评估

    • 训练完成后,在 test.py 中评估模型的表现,通常会计算精度、IoU 等指标。

train.py解析

import os

import hydra
import omegaconf
import pytorch_lightning as pl
import torch
from pytorch_lightning.loggers import TensorBoardLogger

torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True


def hydra_params_to_dotdict(hparams):
    def _to_dot_dict(cfg):
        res = {}
        for k, v in cfg.items():
            if isinstance(v, omegaconf.DictConfig):
                res.update(
                    {k + "." + subk: subv for subk, subv in _to_dot_dict(v).items()}
                )
            elif isinstance(v, (str, int, float, bool)):
                res[k] = v

        return res

    return _to_dot_dict(hparams)


@hydra.main("config/config.yaml")
def main(cfg):
    model = hydra.utils.instantiate(cfg.task_model, hydra_params_to_dotdict(cfg))

    early_stop_callback = pl.callbacks.EarlyStopping(patience=5)
    checkpoint_callback = pl.callbacks.ModelCheckpoint(
        monitor="val_acc",
        mode="max",
        save_top_k=2,
        filepath=os.path.join(
            cfg.task_model.name, "{epoch}-{val_loss:.2f}-{val_acc:.3f}"
        ),
        verbose=True,
    )
    trainer = pl.Trainer(
        gpus=list(cfg.gpus),
        max_epochs=cfg.epochs,
        early_stop_callback=early_stop_callback,
        checkpoint_callback=checkpoint_callback,
        distributed_backend=cfg.distrib_backend,
    )

    trainer.fit(model)


if __name__ == "__main__":
    main()

这段代码是一个基于 PyTorch LightningHydra 的训练脚本,主要用于配置、训练模型,并支持多种训练参数的灵活配置。从各个部分详细讲解这段代码:

1. 导入必要的库

python

import os
import hydra
import omegaconf
import pytorch_lightning as pl
import torch
from pytorch_lightning.loggers import TensorBoardLogger
  • os:用于文件操作,特别是文件路径相关的操作。
  • hydra:用于简化配置文件的管理,支持通过配置文件传递超参数。
  • omegaconf:是 Hydra 底层配置库,用于读取和操作配置文件。
  • pytorch_lightning as pl:PyTorch Lightning 是一个高层次的 API,用于简化 PyTorch 的训练和模型管理,它帮助简化训练循环和模型的部署。
  • torch:PyTorch 的基础库,用于深度学习模型的构建和训练。
  • TensorBoardLogger:用于将训练日志记录到 TensorBoard。

2. 优化 cudnn 设置

torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True
  • cudnn.enabled = True:启用 cuDNN 后端(NVIDIA CUDA 深度神经网络库)以加速运算。
  • cudnn.benchmark = True:当输入数据的大小和形状是固定时,启用 cuDNN 的自动优化算法选择功能,可以提高训练速度。

3. 将 Hydra 配置转换为 dotdict

def hydra_params_to_dotdict(hparams):
    def _to_dot_dict(cfg):
        res = {}
        for k, v in cfg.items():
            if isinstance(v, omegaconf.DictConfig):
                res.update(
                    {k + "." + subk: subv for subk, subv in _to_dot_dict(v).items()}
                )
            elif isinstance(v, (str, int, float, bool)):
                res[k] = v
        return res
    return _to_dot_dict(hparams)
  • hydra_params_to_dotdict 是一个将 Hydra 配置对象(通常是 omegaconf.DictConfig 类型)转换为 Python 字典的函数,字典的键使用“点号”形式(例如 task_model.name)。
  • 这种转换可以方便后续在代码中访问配置参数,避免层级嵌套带来的访问不便。

4. 主函数:训练过程

@hydra.main("config/config.yaml")
def main(cfg):
    model = hydra.utils.instantiate(cfg.task_model, hydra_params_to_dotdict(cfg))
  • @hydra.main("config/config.yaml"):这是 Hydra 的装饰器,表示 main 函数是训练脚本的入口点,config/config.yaml 是默认的配置文件路径。Hydra 会在运行时加载并解析这个 YAML 配置文件中的内容,并将配置参数传递给 cfg
  • cfg 是 Hydra 加载的配置对象,它通常是一个字典类型,包含了整个项目的所有配置(如模型参数、训练参数、设备设置等)。
  • hydra.utils.instantiate(cfg.task_model, hydra_params_to_dotdict(cfg)):使用 cfg.task_model 配置参数实例化模型。cfg.task_model 应该是一个模型类的配置字典,而 hydra_params_to_dotdict(cfg) 则将配置信息转化为 Python 字典,传递给模型初始化函数。

5. 设置回调函数:早停与模型检查点

early_stop_callback = pl.callbacks.EarlyStopping(patience=5)
checkpoint_callback = pl.callbacks.ModelCheckpoint(
    monitor="val_acc",
    mode="max",
    save_top_k=2,
    filepath=os.path.join(
        cfg.task_model.name, "{epoch}-{val_loss:.2f}-{val_acc:.3f}"
    ),
    verbose=True,
)
  • early_stop_callback:用于设置早停策略。如果验证集的表现(如 val_acc)在指定的 patience 轮内没有提升,则停止训练。
  • checkpoint_callback:用于保存最佳模型,基于监控的指标(这里是 val_acc),保存表现最好的 k 个模型。
    • monitor="val_acc":指定监控的指标为验证集上的准确率。
    • mode="max":表明我们希望 val_acc 越大越好。
    • save_top_k=2:保留前 2 个表现最好的模型。
    • filepath:模型保存的路径,其中包含了 epoch、val_loss、val_acc 等信息。
    • verbose=True:打印模型保存的详细信息。

6. 配置训练器并启动训练

trainer = pl.Trainer(
    gpus=list(cfg.gpus),
    max_epochs=cfg.epochs,
    early_stop_callback=early_stop_callback,
    checkpoint_callback=checkpoint_callback,
    distributed_backend=cfg.distrib_backend,
)
trainer.fit(model)
  • gpus=list(cfg.gpus):设置使用的 GPU 设备,cfg.gpus 是一个列表,可以是单个 GPU,也可以是多个 GPU。
  • max_epochs=cfg.epochs:设置训练的最大轮次。
  • early_stop_callback=early_stop_callback:传入早停回调函数。
  • checkpoint_callback=checkpoint_callback:传入模型检查点回调函数。
  • distributed_backend=cfg.distrib_backend:指定分布式训练的后端,可以是 ddp(分布式数据并行)等。
  • trainer.fit(model):启动训练过程,传入实例化后的模型,开始训练。

7. 程序入口

if __name__ == "__main__":
    main()
  • 这部分确保在直接运行这个脚本时,main() 函数会被调用。

总结

这段代码是一个典型的 PyTorch Lightning + Hydra 项目中训练脚本的实现:

  • Hydra 用于动态加载配置文件,支持非常灵活的超参数管理和命令行修改。
  • PyTorch Lightning 负责简化训练过程,包括模型定义、训练循环、回调、日志等。
  • 使用了 早停模型检查点 回调,确保训练能够在最优状态下停止,并保存最佳模型

在代码中,cfg.task_model 是通过 hydra.utils.instantiate(cfg.task_model, hydra_params_to_dotdict(cfg)) 来实例化任务模型的。cfg.task_model 在 Hydra 配置文件中定义,它的具体内容决定了任务和模型的类型。任务模型可能包含分类(classification)模型、分割(segmentation)模型等。

Hydra支持命令行传入修改yaml的参数

python pointnet2/train.py task=cls gpus=[0,1,2,3]

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值