最远点采样-球查询-采样和分组-代码详解
专栏持续更新中!关注博主查看后续部分!
最远点采样、球查询等位于 pointnet2_utils.py 定义
点云坐标归一化
点云坐标归一化是一种预处理步骤,用于将点云数据标准化到一个统一的尺度,通常是在一个特定的范围内,比如 [-1, 1] 或 [0, 1]。这一步骤对于很多三维数据处理和分析任务来说是很重要的,比如三维重建、物体识别、点云分类等。归一化可以帮助改善算法的性能,因为它消除了数据在尺度上的差异,让算法能够更专注于数据的结构和形状特征。
# 归一化点云,以centroid为中心,球半径为1进行归一化 (将点云数据中心化并缩放至单位球内)
# pc 表示输入点云

欧几里得距离计算
# square_distance 用于在 ball query 过程中确定每一个点距离形心的距离
# 函数的输入时两组点, N 为 src 的个数, M 为 dst 的个数, C 为输入点的通道数(特征数)
# 函数的返回是两组点的欧几里得距离,即 N×M
# src: source points, [B, N, C]
# dst: target points, [B, M, C]

对应索引的坐标查询
输入点云集和索引,查询对应的坐标。
# 按照输入点云数据和索引返回索引的点云数据
# 例如 points 为 2×10×3 小批量点云集, idx 为 [[1, 2, 10], [5, 3, 2]] (2×3×3)
# 则返回 Batch1 中第 1, 2, 10 和 Batch2 中第 5, 3, 2 组成的 2×3×3 的点云集

最远点采样
最远点采样(Farthest Point Sampling, FPS)是一种在点云数据中进行子采样的方法,常用于三维计算机视觉和图形处理中。这种方法的目标是从原始点云中选择一组代表性的点,这组点能够尽可能覆盖原始点云的整个形状。最远点采样特别适用于那些需要减少计算量和内存消耗,同时尽量保留几何信息的场景。
# FPS 最远点采样
# 输入: xyz 表示点云数据集合 (B×N×3), npoint 表示采样点的个数(对每个batch来说)
# 输出: centroids (B×npoint) 即每个 batch 采样点的索引

球查询
对给定半径和形心的点云集,查询给定半径的领域点索引。
# 用于寻找球形邻域内的点 (以 centroids 作为中心点, 查找局部区域中的点)
# 输入: radius: 邻域半径; nsample: 局部邻域中的最大采样点数; xyz: 点集; new_xyz: centroids对应的点集 (即降采样后的点集 - 形心坐标)
# 输出为每个 centroids 球形邻域内 nsample 个采样点集的索引 B×S×nsample

采样和分组(局部)
提取局部特征时使用如下函数
# 采样分组 (最远点采样 和 球查询 相结合) 将整个点云分为 npoint 个分散的局部 group, 并查询邻域内对应的坐标
# sample_and_group 和 sample_and_group_all 的区别在于 sample_and_group_all 直接将全部点作为一个group (提取全局特征)
# npoint: 形心点的个数(相比于全局, 降采样的个数); (S)
# radius: 球查询半径大小;
# nsample: 球查询的点数(邻域内, 若不满足 nsample, 超过半径平方的点将被最近的点复制代替)
# xyz: 点的坐标数据; (B × N × 3)
# points: 点的数据 (可能包含法向量等其他特征); (B × N × D)
# returnfps: 是否返回 形心索引 和 各子邻域内点的坐标

采样和分组(全局)
提取全局特征时使用如下函数
# 将所有点作为一个 group, 可以理解为 npoint=1

整体代码如下所示:
import torch
import torch.nn as nn
import torch.nn.functional as F
from time import time
import numpy as np
def timeit(tag, t):
print("{}: {}s".format(tag, time() - t))
return time()
# 归一化点云,以centroid为中心,球半径为1进行归一化 (将点云数据中心化并缩放至单位球内)
# pc 表示输入点云
def pc_normalize(pc):
l = pc.shape[0]
# 计算点云中的形心
centroid = np.mean(pc, axis=0)
# 中心化点云
pc = pc - centroid
# 缩放因子 m = max(sqrt(sum(pi^2)))
# 在 残差MLP中,集合仿射模块与其略微不同 (求平均并添加了一个用于保持分母不为0的值)
m = np.max(np.sqrt(np.sum(pc**2, axis=1)))
pc = pc / m
return pc
# square_distance 用于在 ball query 过程中确定每一个点距离形心的距离
# 函数的输入时两组点, N 为 src 的个数, M 为 dst 的个数, C 为输入点的通道数(特征数)
# 函数的返回是两组点的欧几里得距离,即 N×M
# src: source points, [B, N, C]
# dst: target points, [B, M, C]
def square_distance(src, dst):
"""
Calculate Euclid distance between each two points.
src^T * dst = xn * xm + yn * ym + zn * zm;
sum(src^2, dim=-1) = xn*xn + yn*yn + zn*zn;
sum(dst^2, dim=-1) = xm*xm + ym*ym + zm*zm;
dist = (xn - xm)^2 + (yn - ym)^2 + (zn - zm)^2
= sum(src**2,dim=-1)+sum(dst**2,dim=-1)-2*src^T*dst
Input:
src: source points, [B, N, C]
dst: target points, [B, M, C]
Output:
dist: per-point square distance, [B, N, M]
"""
B, N, _ = src.shape
_, M, _ = dst.shape
dist = -2 * torch.matmul(src, dst.permute(0, 2, 1)) # .permute() 方法用于重新排列张量的维度 tensor.permute(*dims)
# permute() 方法不会修改原始张量, 而是返回一个新张量
dist += torch.sum(src ** 2, -1).view(B, N, 1)
dist += torch.sum(dst ** 2, -1).view(B, 1, M)
return dist
# 按照输入点云数据和索引返回索引的点云数据
# 例如 points 为 2×10×3 小批量点云集, idx 为 [[1, 2, 10], [5, 3, 2]] (2×3×3)
# 则返回 Batch1 中第 1, 2, 10 和 Batch2 中第 5, 3, 2 组成的 2×3×3 的点云集
def index_points(points, idx):
"""
Input:
points: input points data, [B, N, C]
idx: sample index data, [B, S]
Return:
new_points:, indexed points data, [B, S, C]
"""
device = points.device # 获得张量所在的设备
B = points.shape[0] # 批量大小
# 初始化 view_shape - 用于创建批量索引张量的视图
view_shape = list(idx.shape) # 第一个值为 idx 的第一维维度值, 其他值设置为1
view_shape[1:] = [1] * (len(view_shape) - 1)
# 初始化 repeat_shape - 用于指定如何重复批量索引张量
repeat_shape = list(idx.shape) # 第一个值设置为1, 其他值为 idx 除第一维的维度
repeat_shape[0] = 1
# 创建批量索引
batch_indices = torch.arange(B, dtype=torch.long).to(device).view(view_shape).repeat(repeat_shape)
# 首先, 创建一个 0-B-1 的一维张量, 表示每个批次的索引. 然后, 转移到对应的设备上, 并通过 view 调整形状 (batchSize×1×1).
# 最后, 使用 repeat 以 (1×S) 重复该张量, S表示对每个 batch 查询的点数
# 具体来讲: 先创建 [0,1,...,B-1] 的 batch 索引,然后将其调正为 B×1, 并沿着第二维重复每个batch需要查询点个数的次数(S)
# batch_indices 为 B × S 每列为 [0,1,...,B-1] 的数组
new_points = points[batch_indices, idx, :] # 提取每个 batch 需要查询的点对应的坐标
return new_points
# FPS 最远点采样
# 输入: xyz 表示点云数据集合 (B×N×3), npoint 表示采样点的个数(对每个batch来说)
# 输出: centroids (B×npoint) 即每个 batch 采样点的索引
def farthest_point_sample(xyz, npoint):
"""
Input:
xyz: pointcloud data, [B, N, 3]
npoint: number of samples
Return:
centroids: sampled pointcloud index, [B, npoint]
"""
device = xyz.device
B, N, C = xyz.shape
# 初始化 centroids 矩阵,用于存储每个 batch 的 npoint 个采样点的索引值
centroids = torch.zeros(B, npoint, dtype=torch.long).to(device) # B×npoint
# 初始化 distance 矩阵, 用于记录某个 batch 中某个点到所有点欧几里得距离和的最小值
distance = torch.ones(B, N).to(device) * 1e10 # B×N
# 当前的最远点 - 第一次随机生成 B 个 0-N 之间的随机索引
farthest = torch.randint(0, N, (B,), dtype=torch.long).to(device) # B
# batch 的索引数组 0-B-1
batch_indices = torch.arange(B, dtype=torch.long).to(device) # B
# 循环采样 xyz, 每次采样 B 个点, 对应每个 batch 的点索引
for i in range(npoint):
# 当前采样点 centroids 为最远点
centroids[:, i] = farthest
# farthest 大小为B, 值的范围为 0-N-1 (与 batch_indices 尺寸相同, 得到 B 次查找的点的坐标)
centroid = xyz[batch_indices, farthest, :].view(B, 1, 3) # B×1×3
# 求每个 batch 所有点和形心欧几里得距离之和
dist = torch.sum((xyz - centroid) ** 2, -1) # B×N
# 更新迭代
mask = dist < distance
distance[mask] = dist[mask] # 记录每个 batch 中形心(已采样的点)和其他所有点欧几里得距离和的最小值
farthest = torch.max(distance, -1)[1] # 和当前形心距离最远的点作为下一个被采样点
return centroids
# 用于寻找球形邻域内的点 (以 centroids 作为中心点, 查找局部区域中的点)
# 输入: radius: 邻域半径; nsample: 局部邻域中的最大采样点数; xyz: 点集; new_xyz: centroids对应的点集 (即降采样后的点集 - 形心坐标)
# 输出为每个 centroids 球形邻域内 nsample 个采样点集的索引 B×S×nsample
def query_ball_point(radius, nsample, xyz, new_xyz):
"""
Input:
radius: local region radius
nsample: max sample number in local region
xyz: all points, [B, N, 3]
new_xyz: query points, [B, S, 3]
Return:
group_idx: grouped points index, [B, S, nsample]
"""
device = xyz.device
B, N, C = xyz.shape
_, S, _ = new_xyz.shape
# 初始化 group_idx - 每个点对应的索引
group_idx = torch.arange(N, dtype=torch.long).to(device).view(1, 1, N).repeat([B, S, 1]) # B×S×N
# 计算 new_xyz 中每个点和 xyz 中每个点的欧几里得距离 (形心和其他点的欧几里得距离)
sqrdists = square_distance(new_xyz, xyz) # B×S×N
# 将和形心大于半径平方的点对应的欧几里得距离值置为 N (即该batch内点的总数, 也可以理解为序列索引的最大值)
group_idx[sqrdists > radius ** 2] = N # B×S×N
# 升序排列 (默认), 降序是设置参数 descending=True 来实现
group_idx = group_idx.sort(dim=-1)[0][:, :, :nsample] # 取升序排列的前 nsample 个 (B×S×nsample)
# 考虑到前 nsample 点中可能有超过半径平方的点. 为此, 前 nsample 点中超过半径平方的点替换为该领域内升序排列的第一个点
# B×S×nsample 其中, nsample 维都为该形心邻域内最近点的索引
group_first = group_idx[:, :, 0].view(B, S, 1).repeat([1, 1, nsample])
mask = group_idx == N
group_idx[mask] = group_first[mask]
return group_idx # B×S×nsample
# 采样分组 (最远点采样 和 球查询 相结合) 将整个点云分为 npoint 个分散的局部 group, 并查询邻域内对应的坐标
# sample_and_group 和 sample_and_group_all 的区别在于 sample_and_group_all 直接将全部点作为一个group (提取全局特征)
# npoint: 形心点的个数(相比于全局, 降采样的个数); (S)
# radius: 球查询半径大小;
# nsample: 球查询的点数(邻域内, 若不满足 nsample, 超过半径平方的点将被最近的点复制代替)
# xyz: 点的坐标数据; (B × N × 3)
# points: 点的数据 (可能包含法向量等其他特征); (B × N × D)
# returnfps: 是否返回 形心索引 和 各子邻域内点的坐标
def sample_and_group(npoint, radius, nsample, xyz, points, returnfps=False):
"""
Input:
npoint:
radius:
nsample:
xyz: input points position data, [B, N, 3]
points: input points data, [B, N, D]
Return:
new_xyz: sampled points position data, [B, npoint, nsample, 3]
new_points: sampled points data, [B, npoint, nsample, 3+D]
"""
B, N, C = xyz.shape # C=3
S = npoint # 形心点个数 (邻域的个数)
# 最远点采样 - 返回每个 batch 形心对应的索引
fps_idx = farthest_point_sample(xyz, npoint) # B×S
# 查询 fps_idx 索引对应的坐标, 返回形心对应的坐标(new_xyz)
new_xyz = index_points(xyz, fps_idx) # B×S×C
# 球查询, 返回形心对应邻域内点的索引 idx
idx = query_ball_point(radius, nsample, xyz, new_xyz) # B×S×nsample
# 查询 idx 索引对应的坐标(每个形心邻域内 nsample 个点对应的坐标)
grouped_xyz = index_points(xyz, idx) # B×S×nsample×C
# 邻域内坐标正则化
grouped_xyz_norm = grouped_xyz - new_xyz.view(B, S, 1, C) # B×S×nsample×C
if points is not None:
# 当点存在除坐标外的其他特征时, 提取各子邻域内点的其他特征
grouped_points = index_points(points, idx) # B×S×nsample×D
# 沿最后一维拼接为整体邻域特征
new_points = torch.cat([grouped_xyz_norm, grouped_points], dim=-1) # B×S×nsample×(C+D)
else:
new_points = grouped_xyz_norm
# 是否返回 形心索引 和 各子邻域内点的坐标
if returnfps:
# B×S×C; B×S×nsample×(3+D) 或 B×S×nsample×3; B×S×nsample×C; B×S
return new_xyz, new_points, grouped_xyz, fps_idx
else:
# B×S×C; B×S×nsample×(3+D) 或 B×S×nsample×3;
return new_xyz, new_points
# 将所有点作为一个 group, 可以理解为 npoint=1
def sample_and_group_all(xyz, points):
"""
Input:
xyz: input points position data, [B, N, 3]
points: input points data, [B, N, D]
Return:
new_xyz: sampled points position data, [B, 1, 3]
new_points: sampled points data, [B, 1, N, 3+D]
"""
device = xyz.device
B, N, C = xyz.shape # C=3
new_xyz = torch.zeros(B, 1, C).to(device)
grouped_xyz = xyz.view(B, 1, N, C)
# 拼接位置信息外的特征
if points is not None:
new_points = torch.cat([grouped_xyz, points.view(B, 1, N, -1)], dim=-1) # B × 1 × N × (3+D)
else:
new_points = grouped_xyz # B × 1 × N × 3
return new_xyz, new_points
本文详细讲解了PointNet++中的关键步骤——最远点采样(FPS)、球查询、采样和分组的代码实现,包括点云的归一化、欧几里得距离计算以及对应索引的坐标查询。这些操作对于点云处理和三维计算机视觉任务至关重要。


【最远点采样-球查询-采样和分组 代码详解】&spm=1001.2101.3001.5002&articleId=136980292&d=1&t=3&u=054f6ebb6ff34d2fad9ca2e3292c3fa5)
1470

被折叠的 条评论
为什么被折叠?



