目录
matrix_to_quaternion 注释,输入 N*3*3,输出是N*4
代码来源
matrix_to_quaternion 注释,输入 N*3*3,输出是N*4
from typing import Optional, Union
import torch
import torch.nn.functional as F
def matrix_to_quaternion(matrix: torch.Tensor) -> torch.Tensor:
"""
Convert rotations given as rotation matrices to quaternions.
Args:
matrix: Rotation matrices as tensor of shape (..., 3, 3).
Returns:
quaternions with real part first, as tensor of shape (..., 4).
"""
# 检查输入矩阵是否为 3x3 的旋转矩阵
if matrix.size(-1) != 3 or matrix.size(-2) != 3:
raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")
# 获取矩阵的批处理维度
batch_dim = matrix.shape[:-2]
# 展开矩阵并获取各个元素
m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind(
matrix.reshape(batch_dim + (9,)), dim=-1
)
# 计算四元数的绝对值部分(即四元数的平方根的正部分)
q_abs = _sqrt_positive_part(
torch.stack(
[
1.0 + m00 + m11 + m22,
1.0 + m00 - m11 - m22,
1.0 - m00 + m11 - m22,
1.0 - m00 - m11 + m22,
],
dim=-1,
)
)
# 通过计算得到四元数的各个分量(r, i, j, k)
quat_by_rijk = torch.stack(
[
torch.stack([q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1),
torch.stack([m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1),
torch.stack([m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1),
torch.stack([m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1),
],
dim=-2,
)
# 设置一个阈值,避免在数值上出现问题
flr = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device)
# 计算候选四元数(候选的四元数是通过 q_abs 进行归一化的)
quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(flr))
# 选择最佳的候选四元数,这个候选四元数的分母最大(即数值最稳定)
return quat_candidates[
F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, :
].reshape(batch_dim + (4,))
def quaternion_to_axis_angle(quaternions: torch.Tensor) -> torch.Tensor:
"""
Convert rotations given as quaternions to axis/angle.
Args:
quaternions: quaternions with real part first,
as tensor of shape (..., 4).
Returns:
Rotations given as a vector in axis angle form, as a tensor
of shape (..., 3), where the magnitude is the angle
turned anticlockwise in radians around the vector's
direction.
"""
# 计算四元数的规范化部分,即虚部的模长
norms = torch.norm(quaternions[..., 1:], p=2, dim=-1, keepdim=True)
# 计算半角(角度的一半)
half_angles = torch.atan2(norms, quaternions[..., :1])
# 计算实际角度
angles = 2 * half_angles
# 设置一个非常小的阈值,用于处理小角度的情况
eps = 1e-6
small_angles = angles.abs() < eps
# 创建一个空的张量来存储 sin(半角) / 角度 的值
sin_half_angles_over_angles = torch.empty_like(angles)
# 对于较大的角度,直接使用公式计算 sin(半角) / 角度
sin_half_angles_over_angles[~small_angles] = (
torch.sin(half_angles[~small_angles]) / angles[~small_angles]
)
# 对于非常小的角度,采用更精确的近似值
sin_half_angles_over_angles[small_angles] = (
0.5 - (angles[small_angles] * angles[small_angles]) / 48
)
# 返回单位轴向量
return quaternions[..., 1:] / sin_half_angles_over_angles
def matrix_to_axis_angle(matrix: torch.Tensor) -> torch.Tensor:
"""
Convert rotations given as rotation matrices to axis/angle.
Args:
matrix: Rotation matrices as tensor of shape (..., 3, 3).
Returns:
Rotations given as a vector in axis angle form, as a tensor
of shape (..., 3), where the magnitude is the angle
turned anticlockwise in radians around the vector's
direction.
"""
# 将矩阵转换为四元数,然后将四元数转换为轴角表示
return quaternion_to_axis_angle(matrix_to_quaternion(matrix))
向量和矩阵互转
旋转向量转旋转矩阵,输入N*3
2. rotvec2mat(rotvec)
详细说明:
-
输入一个旋转向量(轴角表示),输出对应的旋转矩阵。
-
旋转向量
rotvec
是由旋转轴和旋转角度组成的。 -
使用 Rodrigues' 旋转公式来计算旋转矩阵。
def rotvec2mat(rotvec):
angle = torch.linalg.norm(rotvec, dim=-1, keepdim=True) # 计算旋转角度(向量的模)
axis = torch.nan_to_num(rotvec / angle) # 归一化旋转轴(防止 NaN)
# 使用 Rodrigues' 公式计算旋转矩阵
sin_axis = torch.sin(angle) * axis
cos_angle = torch.cos(angle)
cos1_axis = (1.0 - cos_angle) * axis
_, axis_y, axis_z = torch.unbind(axis, dim=-1)
cos1_axis_x, cos1_axis_y, _ = torch.unbind(cos1_axis, dim=-1)
sin_axis_x, sin_axis_y, sin_axis_z = torch.unbind(sin_axis, dim=-1)
tmp = cos1_axis_x * axis_y
m01 = tmp - sin_axis_z
m10 = tmp + sin_axis_z
tmp = cos1_axis_x * axis_z
m02 = tmp + sin_axis_y
m20 = tmp - sin_axis_y
tmp = cos1_axis_y * axis_z
m12 = tmp - sin_axis_x
m21 = tmp + sin_axis_x
diag = cos1_axis * axis + cos_angle
m00, m11, m22 = torch.unbind(diag, dim=-1)
# 组合成旋转矩阵
matrix = torch.stack((m00, m01, m02, m10, m11, m12, m20, m21, m22), dim=-1)
return torch.unflatten(matrix, -1, (3, 3)) # 转换为 3x3 的矩阵
旋转矩阵转旋转向量
-
输入一个旋转矩阵
rotmat
,输出一个旋转向量,表示旋转轴和旋转角度。 -
通过从旋转矩阵中提取信息来计算旋转向量(使用矩阵的迹、特征值等方法)。
def mat2rotvec(rotmat):
rows = torch.unbind(rotmat, dim=-2) # 将矩阵拆解为行
r = [torch.unbind(row, dim=-1) for row in rows] # 将每行再拆解为列
# 计算旋转矩阵的参数
p10p01 = r[1][0] + r[0][1]
p10m01 = r[1][0] - r[0][1]
p02p20 = r[0][2] + r[2][0]
p02m20 = r[0][2] - r[2][0]
p21p12 = r[2][1] + r[1][2]
p21m12 = r[2][1] - r[1][2]
p00p11 = r[0][0] + r[1][1]
p00m11 = r[0][0] - r[1][1]
_1p22 = 1.0 + r[2][2]
_1m22 = 1.0 - r[2][2]
trace = torch.diagonal(rotmat, dim1=-2, dim2=-1).sum(-1) # 计算矩阵的迹
cond0 = torch.stack((p21m12, p02m20, p10m01, 1.0 + trace), dim=-1)
cond1 = torch.stack((_1m22 + p00m11, p10p01, p02p20, p21m12), dim=-1)
cond2 = torch.stack((p10p01, _1m22 - p00m11, p21p12, p02p20), dim=-1)
cond3 = torch.stack((p02p20, p21p12, _1p22 - p00p11, p10m01), dim=-1)
# 判断旋转矩阵的迹,选择正确的条件
trace_pos = (trace > 0).unsqueeze(-1)
d00_large = torch.logical_and(r[0][0] > r[1][1], r[0][0] > r[2][2]).unsqueeze(-1)
d11_large = (r[1][1] > r[2][2]).unsqueeze(-1)
q = torch.where(trace_pos, cond0, torch.where(d00_large, cond1, torch.where(d11_large, cond2, cond3)))
xyz, w = torch.split(q, (3, 1), dim=-1) # 分离旋转轴和角度
norm = torch.linalg.norm(xyz, dim=-1, keepdim=True) # 计算旋转轴的模长
return (torch.nan_to_num(2.0 / norm) * torch.atan2(norm, w)) * xyz # 返回旋转向量