PPFStructualEmbedding & GeometricStructureEmbedding 模块实现源码精读

  PPFStructureEmbedding 是一种基于 Point Pair Features (PPF) 的几何结构嵌入方法,通常用于点云处理任务中。PPF 是一种描述点云中局部几何关系的特征表示方法,广泛应用于点云配准、物体识别和姿态估计等任务。

  GeometricStructureEmbedding 类是一个 PyTorch 模块,旨在将点云中的几何结构嵌入到高维空间中。该嵌入捕捉了点之间的成对距离和角度关系,将superpoint之间的距离和角度信息引入到self-attention的计算中,这对于许多 3D 视觉任务(如点云分类、分割或目标检测)至关重要。Geometric Structure Embedding的核心思想是利用超点在同一场景不同点云中的距离和角度应当是一致的。

        注意这里PPFStructualEmbeddingGeometricStructureEmbedding的区别,

PPFStructualEmbedding的输入是PPF特征,而GeometricStructureEmbedding模块的输入是点云,GeometricStructureEmbedding中另外集成了提取点云特征的接口。

        

角度嵌入示意图

# Reference: https://github.com/qinzheng93/GeoTransformer

import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F


# 用于计算两个点云(或特征矩阵)之间的成对距离。该函数支持批量处理,并且可以根据输入数据的格式(通道优先或通道后置)进行调整。
# dist (torch.Tensor): 成对距离矩阵,形状为 (*, N, M),表示第一个点云中每个点到第二个点云中每个点的距离。
def pairwise_distance(
    x: torch.Tensor, y: torch.Tensor, normalized: bool = False, channel_first: bool = False
) -> torch.Tensor:
    r"""Pairwise distance of two (batched) point clouds.
    Args:
        x (Tensor): (*, N, C) or (*, C, N)
        y (Tensor): (*, M, C) or (*, C, M)
        normalized (bool=False): if the points are normalized, we have "x2 = y2 = 1", so "d2 = 2 - 2xy".
        channel_first (bool=False): if True, the points shape is (*, C, N).
    Returns:
        dist: torch.Tensor (*, N, M)
    """
    if channel_first:
        channel_dim = -2
        xy = torch.matmul(x.transpose(-1, -2), y)  # [(*, C, N) -> (*, N, C)] x (*, C, M)
    else:
        channel_dim = -1
        xy = torch.matmul(x, y.transpose(-1, -2))  # (*, N, C) x [(*, M, C) -> (*, C, M)]
    if normalized:
        sq_distances = 2.0 - 2.0 * xy # d^2 = (x-y)^2 = x^2 + y^2 - 2xy = 2 - 2xy 即x,y都是单位球上的点
    else:
        x2 = torch.sum(x ** 2, dim=channel_dim).unsqueeze(-1)  # (*, N, C) or (*, C, N) -> (*, N) -> (*, N, 1)
        y2 = torch.sum(y ** 2, dim=channel_dim).unsqueeze(-2)  # (*, M, C) or (*, C, M) -> (*, M) -> (*, 1, M)
        sq_distances = x2 - 2 * xy + y2
    #  使用 clamp 函数确保所有距离值非负,以防止数值不稳定的情况。
    sq_distances = sq_distances.clamp(min=0.0)
    return sq_distances

# SinusoidalPositionEmbedding用法举例
# 创建示例位置索引
# seq_len = 10
# batch_size = 2
# d_model = 8  # 必须是偶数
# emb_indices = torch.arange(seq_len).unsqueeze(0).repeat(batch_size, 1)  # (batch_size, seq_len) 每一个batch都是一个一维序列
#
# # 创建 SinusoidalPositionalEmbedding 实例并获取嵌入
# pos_embedder = SinusoidalPositionalEmbedding(d_model)
# embeddings = pos_embedder(emb_indices)

class SinusoidalPositionalEmbedding(nn.Module):
    def __init__(self, d_model):
        super(SinusoidalPositionalEmbedding, self).__init__()
        # d_model: 位置嵌入的维度。必须是偶数,因为正弦和余弦嵌入各占一半。
        # 在Transformer中,pos embedding 和 input embedding的维度要保持一致,都是d_model。
        if d_model % 2 != 0:
            raise ValueError(f'Sinusoidal positional encoding with odd d_model: {d_model}')
        self.d_model = d_model
        div_indices = torch.arange(0, d_model, 2).float() # 生成一个从 0 到 d_model 的偶数索引数组,表示每个维度的索引。2表示Step,隔两个取一个数。
        # 使用指数函数计算每个位置嵌入维度的除数项。
        div_term = torch.exp(div_indices * (-np.log(10000.0) / d_model)) # 计算位置编码中的除数项,输出为一个张量,表示每个维度的除数项。 与div_indices维度一样。
        self.register_buffer('div_term', div_term)

    def forward(self, emb_indices):
        r"""Sinusoidal Positional Embedding.
        Args:
            emb_indices: torch.Tensor (*) 表示每个位置的索引 示例的维度为(2, 10) 2表示batch,10是seq_len
        Returns:
            embeddings: torch.Tensor (*, D)
        """
        input_shape = emb_indices.shape # 示例运行结果维度 (2,10)
        omegas = emb_indices.view(-1, 1, 1) * self.div_term.view(1, -1, 1)  # (-1, d_model/2, 1) 第0个维度是seq长度,第1个维度是d_model/2,与div_indices长度一致
        # 分别生成正弦和余弦嵌入
        sin_embeddings = torch.sin(omegas)
        cos_embeddings = torch.cos(omegas)
        # 将正弦和余弦位置嵌入组合成一个单一的嵌入向量,并确保该向量的形状与输入数据的形状兼容。这样做的好处是可以将位置信息有效地融入到
        # 模型的输入中,这对于诸如Transformer等模型来说是非常重要的。
        embeddings = torch.cat([sin_embeddings, cos_embeddings], dim=2)  # (-1, d_model/2, 2) 示例的输出维度为 (20, 4, 2)
        # 在 PyTorch中,* 常用于解包张量的形状(shape),以便将形状的各个维度作为独立的参数传递给函数。
        # view 函数用于改变张量(Tensor)的形状,它返回一个新的张量,该张量与原张量共享底层数据存储,只是对数据的布局进行了重新解释。
        # 即view不改变存储顺序,只改变对张量数据的索引方式。
        embeddings = embeddings.view(*input_shape, self.d_model)  # (*, d_model)  示例运行结果维度 (2, 10, 8)
        embeddings = embeddings.detach() # 将嵌入从计算图中分离,避免计算梯度。
        return embeddings

# PPFStructualEmbedding示例代码
# ppf = torch.randn(2, 10, 4)
# embedding_module = PPFStructualEmbedding(hidden_dim=128, mode='global')
# embeddings = embedding_module(ppf)
# 用于生成结构化嵌入(Structural Embedding),特别是针对点云数据中的 PPF(Point Pair Features)特征。
# PPF 是一种描述点对之间几何关系的特征,通常用于 3D 点云处理任务,如点云配准、分类和分割。
class PPFStructualEmbedding(nn.Module):
    def __init__(self, hidden_dim, mode='local'):
        super(PPFStructualEmbedding, self).__init__()
        if mode == 'local':
            self.embedding = SinusoidalPositionalEmbedding(hidden_dim) # hidden_dim表示位置编码的维度,参见SinusoidalPositionalEmbedding定义
            self.proj = nn.Linear(4, hidden_dim) # 线性投影层,用于将输入特征映射到目标维度。
        elif mode == 'global':
            self.embedding = SinusoidalPositionalEmbedding(hidden_dim // 4)
            self.proj = nn.Linear(hidden_dim, hidden_dim) # 线性投影层,用于将输入特征映射到目标维度。
        else:
            raise 'mode should be in [local, global]'
        self.mode = mode

    # 输入的ppf特征,形状为(batch_size, num_points, 4),其中每个点对有4个特征值(例如,距离和三个角度)。
    def forward(self, ppf):
        # 在 PPF(Point Pair Features)进行位置嵌入编码时,使用各个特征值(如距离和角度)而不是传统的“位置索引”(pos),
        # 主要是因为 PPF 特征描述的是点对之间的几何关系,而不是点在序列中的绝对位置。
        if self.mode == 'local':
            embeddings = self.proj(ppf)
        elif self.mode == 'global':
            # 对每个特征分别应用正弦位置编码模块self.embedding,生成嵌入。
            d_embeddings = self.embedding(ppf[..., 0])  # 距离特征。
            a_embeddings0 = self.embedding(ppf[..., 1]) # 角度特征1
            a_embeddings1 = self.embedding(ppf[..., 2]) # 角度特征2
            a_embeddings2 = self.embedding(ppf[..., 3]) # 角度特征3
            # 将生成的嵌入拼接起来,形成一个完整的嵌入向量。
            embeddings = torch.cat([d_embeddings, a_embeddings0, a_embeddings1, a_embeddings2], dim=-1)
            # 将拼接后的嵌入通过线性投影层 self.proj 映射到目标维度 hidden_dim。
            embeddings = self.proj(embeddings)
            # 使用L2归一化(F.normalize)将嵌入向量的范数归一化为1,以增强嵌入的稳定性和可比性。
            embeddings = F.normalize(embeddings, dim=-1, p=2)
        else:
            raise 'mode should be in [local, global]'

        return embeddings

# ppf = torch.randn(2, 10, 4)
# embedding_module = PPFStructualEmbedding(hidden_dim=128, mode='global')
# embeddings = embedding_module(ppf)

# 以下模块将点云中的几何结构嵌入到高维空间中。
# 该嵌入捕捉了点之间的成对距离和角度关系,这对于许多 3D 视觉任务(如点云分类、分割或目标检测)至关重要。
# 该类提供了一种强大的方法来编码点云中的几何信息,可以用作分类或分割等下游任务的输入。
# reduction_a: 用于聚合角度嵌入的缩减方法(max或mean).
class GeometricStructureEmbedding(nn.Module):
    def __init__(self, hidden_dim, sigma_d, sigma_a, angle_k, reduction_a='max'):
        super(GeometricStructureEmbedding, self).__init__()
        self.sigma_d = sigma_d # 距离嵌入的缩放因子
        self.sigma_a = sigma_a # 角度嵌入的缩放因子
        self.factor_a = 180.0 / (self.sigma_a * np.pi) 
        self.angle_k = angle_k # 用于角度嵌入的最近邻点的数量
        # 这里使用的是固定的k值,根据点密度动态选择邻居数量,而不是固定k,可以提高鲁棒性。

        self.embedding = SinusoidalPositionalEmbedding(hidden_dim)
        self.proj_d = nn.Linear(hidden_dim, hidden_dim)
        self.proj_a = nn.Linear(hidden_dim, hidden_dim)

        self.reduction_a = reduction_a
        if self.reduction_a not in ['max', 'mean']:
            raise ValueError(f'Unsupported reduction mode: {self.reduction_a}.')

    # 根据输入点云计算距离和角度嵌入的索引。
    @torch.no_grad()
    def get_embedding_indices(self, points):
        r"""Compute the indices of pair-wise distance embedding and triplet-wise angular embedding.
        Args:
            points: torch.Tensor (B, N, 3), input point cloud
        Returns:
            d_indices: torch.FloatTensor (B, N, N), distance embedding indices 包含成对距离索引
            a_indices: torch.FloatTensor (B, N, N, k), angular embedding indices 包含k个最近邻点的角度索引
        """
        batch_size, num_point, _ = points.shape
        # 距离嵌入:捕捉点之间的成对距离,对于理解点云的全局结构至关重要。
        dist_map = torch.sqrt(pairwise_distance(points, points))  # (B, N, N) 计算所有点之间的成对距离
        d_indices = dist_map / self.sigma_d # 通过sigma_d对距离进行归一化,得到d_indices

        # 角度嵌入:通过考虑相邻点之间的角度来捕捉局部几何关系
        k = self.angle_k
        knn_indices = dist_map.topk(k=k + 1, dim=2, largest=False)[1][:, :, 1:]  # (B, N, k) 使用topk找到每个点的k个最近邻点
        knn_indices = knn_indices.unsqueeze(3).expand(batch_size, num_point, k, 3)  # (B, N, k, 3)
        expanded_points = points.unsqueeze(1).expand(batch_size, num_point, num_point, 3)  # (B, N, N, 3)
        knn_points = torch.gather(expanded_points, dim=2, index=knn_indices)  # (B, N, k, 3)

        ref_vectors = knn_points - points.unsqueeze(2)  # (B, N, k, 3) 计算每个点与其邻居之间的向量(ref_vectors)
        anc_vectors = points.unsqueeze(1) - points.unsqueeze(2)  # (B, N, N, 3) 计算所有点对之间的向量(anc_vectors)。
        ref_vectors = ref_vectors.unsqueeze(2).expand(batch_size, num_point, num_point, k, 3)  # (B, N, N, k, 3)
        anc_vectors = anc_vectors.unsqueeze(3).expand(batch_size, num_point, num_point, k, 3)  # (B, N, N, k, 3)
        # 使用叉积和点积计算这些向量之间的正弦和余弦值
        sin_values = torch.linalg.norm(torch.cross(ref_vectors, anc_vectors, dim=-1), dim=-1)  # (B, N, N, k)
        cos_values = torch.sum(ref_vectors * anc_vectors, dim=-1)  # (B, N, N, k)
        # 使用atan2计算角度,并通过self.fafctor_a对其进行缩放,得到a_indices
        angles = torch.atan2(sin_values, cos_values)  # (B, N, N, k)
        a_indices = angles * self.factor_a

        return d_indices, a_indices

    # 通过结合距离和角度嵌入来计算最终的几何结构嵌入。
    def forward(self, points): 
        d_indices, a_indices = self.get_embedding_indices(points)
        # d_indices (B, N, N)
        # a_indices (B, N, N, k)

        d_embeddings = self.embedding(d_indices) # 嵌入距离索引 
        d_embeddings = self.proj_d(d_embeddings) # 使用proj_d对其进行投影

        a_embeddings = self.embedding(a_indices) # 嵌入角度索引
        a_embeddings = self.proj_a(a_embeddings) # 使用proj_a对其进行投影
        # 使用 max 或 mean 缩减方法跨 k 维度聚合角度嵌入。使模型能够专注于最显著或平均的角度特征。
        if self.reduction_a == 'max':
            a_embeddings = a_embeddings.max(dim=3)[0]
        else:
            a_embeddings = a_embeddings.mean(dim=3)
        # 通过求和将距离和角度嵌入结合起来。
        embeddings = d_embeddings + a_embeddings

        return embeddings

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值