FPS代码,原理,笔记,计算流程

一.代码

import torch

def farthest_point_sample(point_cloud, num_samples):
    """
    对点云进行FPS采样。
    
    参数:
        point_cloud (torch.Tensor): 输入点云,形状为 [B, N, 3],B是批次大小,N是点云中点的数量,3是每个点的坐标。
        num_samples (int): 需要采样的点的数量。
    
    返回:
        sampled_indices (torch.Tensor): 采样点的索引,形状为 [B, num_samples]。
    """
    B, N, C = point_cloud.shape
    device = point_cloud.device

    # 初始化采样点的索引
    sampled_indices = torch.zeros(B, num_samples, dtype=torch.long, device=device)

    # 初始化距离矩阵,记录每个点到已选点的最小距离
    distance = torch.ones(B, N, device=device) * 1e10

    # 随机选择第一个点
    farthest = torch.randint(0, N, (B,), dtype=torch.long, device=device)
    batch_indices = torch.arange(B, dtype=torch.long, device=device)
    #batch_indices表示批次中每个点的索引,【0,B-1】

    for i in range(num_samples):
        # 将当前最远点加入采样结果
        sampled_indices[:, i] = farthest

        # 获取当前最远点的坐标
        centroid = point_cloud[batch_indices, farthest, :].view(B, 1, 3)

        # 计算所有点到当前最远点的距离
        dist = torch.sum((point_cloud - centroid) ** 2, dim=-1)

        # 更新每个点的最小距离
        mask = dist < distance
        distance[mask] = dist[mask]

        # 选择距离最远的点作为下一个点
        farthest = torch.argmax(distance, dim=-1)

    return sampled_indices

二.原理

FPS采样的原理

  1. 初始化

    随机选择一个点作为第一个采样点。
  2. 迭代过程

        在每次迭代中,计算每个未选点到已选点集的最小距离。

        选择具有最大最小距离的点作为下一个采样点。

     3. 终止条件:        

        重复上述步骤,直到达到预设的采样点数。

具体步骤

  1. 随机选择第一个点

    从点云中随机选取一个点 p0p0​,加入采样点集 SS。
  2. 计算距离

    对于每个未选点 pipi​,计算其到 SS 中所有点的最小距离 didi​。
  3. 选择最远点

    选取 didi​ 最大的点 pjpj​,加入 SS。
  4. 重复

    重复步骤2和3,直到 SS 的大小达到目标采样数。

三.笔记

  • torch.long是 PyTorch 中的一种数据类型,表示 64 位整数(64-bit integer)。它通常用于存储整数值,尤其是在需要较大整数范围时
  • distance矩阵使用torch.ones是为了让所有点成为候选点
  • torch.randit:随即函数,用于生成一个随机torch张量。(0,N)是随机取值范围。(B,)为张量形状,一维长度B

四.计算流程

假设我们有一个点云,包含6个点,坐标如下:

point_cloud = [
    [0.0, 0.0, 0.0],  # 点0
    [1.0, 0.0, 0.0],  # 点1
    [0.0, 1.0, 0.0],  # 点2
    [1.0, 1.0, 0.0],  # 点3
    [0.5, 0.5, 0.0],  # 点4
    [0.7, 0.7, 0.0]   # 点5
]

  • 点云形状:[1, 6, 3](1个点云,6个点,每个点3个坐标)。

  • 目标采样点数:num_samples = 3(选择3个点)。


代码执行过程

以下是FPS算法的代码,我们将逐行执行并解释。

初始化
  1. 输入数据

    • point_cloud[1, 6, 3]

    • num_samples = 3

  2. 初始化变量

    • sampled_indices[1, 3],初始值为[[0, 0, 0]]

    • distance[1, 6],初始值为[1e10, 1e10, 1e10, 1e10, 1e10, 1e10]

    • farthest:随机选择一个起始点,假设选择点0,即farthest = [0]


第一次迭代(i = 0)
  1. sampled_indices[:, 0] = farthest

    • 将当前最远点(点0)的索引存储到 sampled_indices 的第0列。

    • sampled_indices 更新为 [[0, 0, 0]]

  2. centroid = point_cloud[batch_indices, farthest, :].view(1, 1, 3)

    • 提取点0的坐标 [0.0, 0.0, 0.0],并将其形状调整为 [1, 1, 3]

    • centroid[[[0.0, 0.0, 0.0]]]

  3. dist = torch.sum((point_cloud - centroid) ** 2, dim=-1)

    • 计算所有点到点0的距离的平方。

    • 结果 dist[[0.0, 1.0, 1.0, 2.0, 0.5, 0.98]](分别是点0、点1、点2、点3、点4、点5到点0的距离平方)。

  4. mask = dist < distance

    • 比较 distdistance,生成布尔掩码。

    • mask[[True, True, True, True, True, True]](因为初始 distance1e10,所有点都满足条件)。

  5. distance[mask] = dist[mask]

    • 更新 distance[[0.0, 1.0, 1.0, 2.0, 0.5, 0.98]]

  6. farthest = torch.argmax(distance, dim=-1)

    • 找到距离最远的点。distance 中最大值是 2.0,对应的点是点3。

    • farthest 更新为 [3]


第二次迭代(i = 1)
  1. sampled_indices[:, 1] = farthest

    • 将当前最远点(点3)的索引存储到 sampled_indices 的第1列。

    • sampled_indices 更新为 [[0, 3, 0]]

  2. centroid = point_cloud[batch_indices, farthest, :].view(1, 1, 3)

    • 提取点3的坐标 [1.0, 1.0, 0.0],并将其形状调整为 [1, 1, 3]

    • centroid[[[1.0, 1.0, 0.0]]]

  3. dist = torch.sum((point_cloud - centroid) ** 2, dim=-1)

    • 计算所有点到点3的距离的平方。

    • 结果 dist[[2.0, 1.0, 1.0, 0.0, 0.5, 0.18]](分别是点0、点1、点2、点3、点4、点5到点3的距离平方)。

  4. mask = dist < distance

    • 比较 distdistance,生成布尔掩码。

    • mask[[False, True, True, True, True, True]](点0的距离 2.0 不小于当前 distance 中的 0.0)。

  5. distance[mask] = dist[mask]

    • 更新 distance[[0.0, 1.0, 1.0, 0.0, 0.5, 0.18]]

  6. farthest = torch.argmax(distance, dim=-1)

    • 找到距离最远的点。distance 中最大值是 1.0,对应的点是点1和点2。

    • 选择其中一个点(如点1),farthest 更新为 [1]


第三次迭代(i = 2)
  1. sampled_indices[:, 2] = farthest

    • 将当前最远点(点1)的索引存储到 sampled_indices 的第2列。

    • sampled_indices 更新为 [[0, 3, 1]]

  2. centroid = point_cloud[batch_indices, farthest, :].view(1, 1, 3)

    • 提取点1的坐标 [1.0, 0.0, 0.0],并将其形状调整为 [1, 1, 3]

    • centroid[[[1.0, 0.0, 0.0]]]

  3. dist = torch.sum((point_cloud - centroid) ** 2, dim=-1)

    • 计算所有点到点1的距离的平方。

    • 结果 dist[[1.0, 0.0, 2.0, 1.0, 0.25, 0.49]](分别是点0、点1、点2、点3、点4、点5到点1的距离平方)。

  4. mask = dist < distance

    • 比较 distdistance,生成布尔掩码。

    • mask[[False, True, False, False, True, True]](点0、点2、点3的距离不小于当前 distance 中的值)。

  5. distance[mask] = dist[mask]

    • 更新 distance[[0.0, 0.0, 1.0, 0.0, 0.25, 0.18]]

  6. farthest = torch.argmax(distance, dim=-1)

    • 找到距离最远的点。distance 中最大值是 1.0,对应的点是点2。

    • farthest 更新为 [2]


最终结果
  • sampled_indices[[0, 3, 1]],表示选择了点0、点3和点1。

  • 采样后的点云为:

sampled_point_cloud = [
    [0.0, 0.0, 0.0],  # 点0
    [1.0, 1.0, 0.0],  # 点3
    [1.0, 0.0, 0.0]   # 点1
]
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值