一.代码
import torch
def farthest_point_sample(point_cloud, num_samples):
"""
对点云进行FPS采样。
参数:
point_cloud (torch.Tensor): 输入点云,形状为 [B, N, 3],B是批次大小,N是点云中点的数量,3是每个点的坐标。
num_samples (int): 需要采样的点的数量。
返回:
sampled_indices (torch.Tensor): 采样点的索引,形状为 [B, num_samples]。
"""
B, N, C = point_cloud.shape
device = point_cloud.device
# 初始化采样点的索引
sampled_indices = torch.zeros(B, num_samples, dtype=torch.long, device=device)
# 初始化距离矩阵,记录每个点到已选点的最小距离
distance = torch.ones(B, N, device=device) * 1e10
# 随机选择第一个点
farthest = torch.randint(0, N, (B,), dtype=torch.long, device=device)
batch_indices = torch.arange(B, dtype=torch.long, device=device)
#batch_indices表示批次中每个点的索引,【0,B-1】
for i in range(num_samples):
# 将当前最远点加入采样结果
sampled_indices[:, i] = farthest
# 获取当前最远点的坐标
centroid = point_cloud[batch_indices, farthest, :].view(B, 1, 3)
# 计算所有点到当前最远点的距离
dist = torch.sum((point_cloud - centroid) ** 2, dim=-1)
# 更新每个点的最小距离
mask = dist < distance
distance[mask] = dist[mask]
# 选择距离最远的点作为下一个点
farthest = torch.argmax(distance, dim=-1)
return sampled_indices
二.原理
FPS采样的原理
-
初始化:
随机选择一个点作为第一个采样点。 -
迭代过程:
在每次迭代中,计算每个未选点到已选点集的最小距离。
选择具有最大最小距离的点作为下一个采样点。
3. 终止条件:
重复上述步骤,直到达到预设的采样点数。
具体步骤
-
随机选择第一个点:
从点云中随机选取一个点 p0p0,加入采样点集 SS。 -
计算距离:
对于每个未选点 pipi,计算其到 SS 中所有点的最小距离 didi。 -
选择最远点:
选取 didi 最大的点 pjpj,加入 SS。 -
重复:
重复步骤2和3,直到 SS 的大小达到目标采样数。
三.笔记
- torch.long是 PyTorch 中的一种数据类型,表示 64 位整数(64-bit integer)。它通常用于存储整数值,尤其是在需要较大整数范围时
- distance矩阵使用torch.ones是为了让所有点成为候选点
- torch.randit:随即函数,用于生成一个随机torch张量。(0,N)是随机取值范围。(B,)为张量形状,一维长度B
四.计算流程
假设我们有一个点云,包含6个点,坐标如下:
point_cloud = [ [0.0, 0.0, 0.0], # 点0 [1.0, 0.0, 0.0], # 点1 [0.0, 1.0, 0.0], # 点2 [1.0, 1.0, 0.0], # 点3 [0.5, 0.5, 0.0], # 点4 [0.7, 0.7, 0.0] # 点5 ]
-
点云形状:
[1, 6, 3]
(1个点云,6个点,每个点3个坐标)。 -
目标采样点数:
num_samples = 3
(选择3个点)。
代码执行过程
以下是FPS算法的代码,我们将逐行执行并解释。
初始化
-
输入数据:
-
point_cloud
:[1, 6, 3]
。 -
num_samples = 3
。
-
-
初始化变量:
-
sampled_indices
:[1, 3]
,初始值为[[0, 0, 0]]
。 -
distance
:[1, 6]
,初始值为[1e10, 1e10, 1e10, 1e10, 1e10, 1e10]
。 -
farthest
:随机选择一个起始点,假设选择点0,即farthest = [0]
。
-
第一次迭代(i = 0)
-
sampled_indices[:, 0] = farthest
:-
将当前最远点(点0)的索引存储到
sampled_indices
的第0列。 -
sampled_indices
更新为[[0, 0, 0]]
。
-
-
centroid = point_cloud[batch_indices, farthest, :].view(1, 1, 3)
:-
提取点0的坐标
[0.0, 0.0, 0.0]
,并将其形状调整为[1, 1, 3]
。 -
centroid
为[[[0.0, 0.0, 0.0]]]
。
-
-
dist = torch.sum((point_cloud - centroid) ** 2, dim=-1)
:-
计算所有点到点0的距离的平方。
-
结果
dist
为[[0.0, 1.0, 1.0, 2.0, 0.5, 0.98]]
(分别是点0、点1、点2、点3、点4、点5到点0的距离平方)。
-
-
mask = dist < distance
:-
比较
dist
和distance
,生成布尔掩码。 -
mask
为[[True, True, True, True, True, True]]
(因为初始distance
是1e10
,所有点都满足条件)。
-
-
distance[mask] = dist[mask]
:-
更新
distance
为[[0.0, 1.0, 1.0, 2.0, 0.5, 0.98]]
。
-
-
farthest = torch.argmax(distance, dim=-1)
:-
找到距离最远的点。
distance
中最大值是2.0
,对应的点是点3。 -
farthest
更新为[3]
。
-
第二次迭代(i = 1)
-
sampled_indices[:, 1] = farthest
:-
将当前最远点(点3)的索引存储到
sampled_indices
的第1列。 -
sampled_indices
更新为[[0, 3, 0]]
。
-
-
centroid = point_cloud[batch_indices, farthest, :].view(1, 1, 3)
:-
提取点3的坐标
[1.0, 1.0, 0.0]
,并将其形状调整为[1, 1, 3]
。 -
centroid
为[[[1.0, 1.0, 0.0]]]
。
-
-
dist = torch.sum((point_cloud - centroid) ** 2, dim=-1)
:-
计算所有点到点3的距离的平方。
-
结果
dist
为[[2.0, 1.0, 1.0, 0.0, 0.5, 0.18]]
(分别是点0、点1、点2、点3、点4、点5到点3的距离平方)。
-
-
mask = dist < distance
:-
比较
dist
和distance
,生成布尔掩码。 -
mask
为[[False, True, True, True, True, True]]
(点0的距离2.0
不小于当前distance
中的0.0
)。
-
-
distance[mask] = dist[mask]
:-
更新
distance
为[[0.0, 1.0, 1.0, 0.0, 0.5, 0.18]]
。
-
-
farthest = torch.argmax(distance, dim=-1)
:-
找到距离最远的点。
distance
中最大值是1.0
,对应的点是点1和点2。 -
选择其中一个点(如点1),
farthest
更新为[1]
。
-
第三次迭代(i = 2)
-
sampled_indices[:, 2] = farthest
:-
将当前最远点(点1)的索引存储到
sampled_indices
的第2列。 -
sampled_indices
更新为[[0, 3, 1]]
。
-
-
centroid = point_cloud[batch_indices, farthest, :].view(1, 1, 3)
:-
提取点1的坐标
[1.0, 0.0, 0.0]
,并将其形状调整为[1, 1, 3]
。 -
centroid
为[[[1.0, 0.0, 0.0]]]
。
-
-
dist = torch.sum((point_cloud - centroid) ** 2, dim=-1)
:-
计算所有点到点1的距离的平方。
-
结果
dist
为[[1.0, 0.0, 2.0, 1.0, 0.25, 0.49]]
(分别是点0、点1、点2、点3、点4、点5到点1的距离平方)。
-
-
mask = dist < distance
:-
比较
dist
和distance
,生成布尔掩码。 -
mask
为[[False, True, False, False, True, True]]
(点0、点2、点3的距离不小于当前distance
中的值)。
-
-
distance[mask] = dist[mask]
:-
更新
distance
为[[0.0, 0.0, 1.0, 0.0, 0.25, 0.18]]
。
-
-
farthest = torch.argmax(distance, dim=-1)
:-
找到距离最远的点。
distance
中最大值是1.0
,对应的点是点2。 -
farthest
更新为[2]
。
-
最终结果
-
sampled_indices
为[[0, 3, 1]]
,表示选择了点0、点3和点1。 -
采样后的点云为:
sampled_point_cloud = [ [0.0, 0.0, 0.0], # 点0 [1.0, 1.0, 0.0], # 点3 [1.0, 0.0, 0.0] # 点1 ]