旋转矩阵转轴角 向量和矩阵互转

目录

代码来源

matrix_to_quaternion 注释,输入 N*3*3,输出是N*4

向量和矩阵互转

旋转向量转旋转矩阵,输入N*3

2. rotvec2mat(rotvec)

旋转矩阵转旋转向量


代码来源

GitHub - yohanshin/WHAM

matrix_to_quaternion 注释,输入 N*3*3,输出是N*4

from typing import Optional, Union
import torch
import torch.nn.functional as F

def matrix_to_quaternion(matrix: torch.Tensor) -> torch.Tensor:
    """
    Convert rotations given as rotation matrices to quaternions.

    Args:
        matrix: Rotation matrices as tensor of shape (..., 3, 3).

    Returns:
        quaternions with real part first, as tensor of shape (..., 4).
    """
    # 检查输入矩阵是否为 3x3 的旋转矩阵
    if matrix.size(-1) != 3 or matrix.size(-2) != 3:
        raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")

    # 获取矩阵的批处理维度
    batch_dim = matrix.shape[:-2]
    
    # 展开矩阵并获取各个元素
    m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind(
        matrix.reshape(batch_dim + (9,)), dim=-1
    )

    # 计算四元数的绝对值部分(即四元数的平方根的正部分)
    q_abs = _sqrt_positive_part(
        torch.stack(
            [
                1.0 + m00 + m11 + m22,
                1.0 + m00 - m11 - m22,
                1.0 - m00 + m11 - m22,
                1.0 - m00 - m11 + m22,
            ],
            dim=-1,
        )
    )

    # 通过计算得到四元数的各个分量(r, i, j, k)
    quat_by_rijk = torch.stack(
        [
            torch.stack([q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1),
            torch.stack([m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1),
            torch.stack([m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1),
            torch.stack([m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1),
        ],
        dim=-2,
    )

    # 设置一个阈值,避免在数值上出现问题
    flr = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device)
    
    # 计算候选四元数(候选的四元数是通过 q_abs 进行归一化的)
    quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(flr))

    # 选择最佳的候选四元数,这个候选四元数的分母最大(即数值最稳定)
    return quat_candidates[
        F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, :
    ].reshape(batch_dim + (4,))


def quaternion_to_axis_angle(quaternions: torch.Tensor) -> torch.Tensor:
    """
    Convert rotations given as quaternions to axis/angle.

    Args:
        quaternions: quaternions with real part first,
            as tensor of shape (..., 4).

    Returns:
        Rotations given as a vector in axis angle form, as a tensor
            of shape (..., 3), where the magnitude is the angle
            turned anticlockwise in radians around the vector's
            direction.
    """
    # 计算四元数的规范化部分,即虚部的模长
    norms = torch.norm(quaternions[..., 1:], p=2, dim=-1, keepdim=True)
    
    # 计算半角(角度的一半)
    half_angles = torch.atan2(norms, quaternions[..., :1])
    
    # 计算实际角度
    angles = 2 * half_angles
    
    # 设置一个非常小的阈值,用于处理小角度的情况
    eps = 1e-6
    small_angles = angles.abs() < eps
    
    # 创建一个空的张量来存储 sin(半角) / 角度 的值
    sin_half_angles_over_angles = torch.empty_like(angles)
    
    # 对于较大的角度,直接使用公式计算 sin(半角) / 角度
    sin_half_angles_over_angles[~small_angles] = (
        torch.sin(half_angles[~small_angles]) / angles[~small_angles]
    )
    
    # 对于非常小的角度,采用更精确的近似值
    sin_half_angles_over_angles[small_angles] = (
        0.5 - (angles[small_angles] * angles[small_angles]) / 48
    )
    
    # 返回单位轴向量
    return quaternions[..., 1:] / sin_half_angles_over_angles


def matrix_to_axis_angle(matrix: torch.Tensor) -> torch.Tensor:
    """
    Convert rotations given as rotation matrices to axis/angle.

    Args:
        matrix: Rotation matrices as tensor of shape (..., 3, 3).

    Returns:
        Rotations given as a vector in axis angle form, as a tensor
            of shape (..., 3), where the magnitude is the angle
            turned anticlockwise in radians around the vector's
            direction.
    """
    # 将矩阵转换为四元数,然后将四元数转换为轴角表示
    return quaternion_to_axis_angle(matrix_to_quaternion(matrix))

向量和矩阵互转

旋转向量转旋转矩阵,输入N*3

2. rotvec2mat(rotvec)

详细说明

  • 输入一个旋转向量(轴角表示),输出对应的旋转矩阵。

  • 旋转向量 rotvec 是由旋转轴和旋转角度组成的。

  • 使用 Rodrigues' 旋转公式来计算旋转矩阵。

def rotvec2mat(rotvec):
    angle = torch.linalg.norm(rotvec, dim=-1, keepdim=True)  # 计算旋转角度(向量的模)
    axis = torch.nan_to_num(rotvec / angle)  # 归一化旋转轴(防止 NaN)
    
    # 使用 Rodrigues' 公式计算旋转矩阵
    sin_axis = torch.sin(angle) * axis
    cos_angle = torch.cos(angle)
    cos1_axis = (1.0 - cos_angle) * axis
    _, axis_y, axis_z = torch.unbind(axis, dim=-1)
    cos1_axis_x, cos1_axis_y, _ = torch.unbind(cos1_axis, dim=-1)
    sin_axis_x, sin_axis_y, sin_axis_z = torch.unbind(sin_axis, dim=-1)
    
    tmp = cos1_axis_x * axis_y
    m01 = tmp - sin_axis_z
    m10 = tmp + sin_axis_z
    tmp = cos1_axis_x * axis_z
    m02 = tmp + sin_axis_y
    m20 = tmp - sin_axis_y
    tmp = cos1_axis_y * axis_z
    m12 = tmp - sin_axis_x
    m21 = tmp + sin_axis_x
    diag = cos1_axis * axis + cos_angle
    m00, m11, m22 = torch.unbind(diag, dim=-1)
    
    # 组合成旋转矩阵
    matrix = torch.stack((m00, m01, m02, m10, m11, m12, m20, m21, m22), dim=-1)
    return torch.unflatten(matrix, -1, (3, 3))  # 转换为 3x3 的矩阵

旋转矩阵转旋转向量

  • 输入一个旋转矩阵 rotmat,输出一个旋转向量,表示旋转轴和旋转角度。

  • 通过从旋转矩阵中提取信息来计算旋转向量(使用矩阵的迹、特征值等方法)。

def mat2rotvec(rotmat):
    rows = torch.unbind(rotmat, dim=-2)  # 将矩阵拆解为行
    r = [torch.unbind(row, dim=-1) for row in rows]  # 将每行再拆解为列
    
    # 计算旋转矩阵的参数
    p10p01 = r[1][0] + r[0][1]
    p10m01 = r[1][0] - r[0][1]
    p02p20 = r[0][2] + r[2][0]
    p02m20 = r[0][2] - r[2][0]
    p21p12 = r[2][1] + r[1][2]
    p21m12 = r[2][1] - r[1][2]
    p00p11 = r[0][0] + r[1][1]
    p00m11 = r[0][0] - r[1][1]
    _1p22 = 1.0 + r[2][2]
    _1m22 = 1.0 - r[2][2]
    
    trace = torch.diagonal(rotmat, dim1=-2, dim2=-1).sum(-1)  # 计算矩阵的迹
    cond0 = torch.stack((p21m12, p02m20, p10m01, 1.0 + trace), dim=-1)
    cond1 = torch.stack((_1m22 + p00m11, p10p01, p02p20, p21m12), dim=-1)
    cond2 = torch.stack((p10p01, _1m22 - p00m11, p21p12, p02p20), dim=-1)
    cond3 = torch.stack((p02p20, p21p12, _1p22 - p00p11, p10m01), dim=-1)
    
    # 判断旋转矩阵的迹,选择正确的条件
    trace_pos = (trace > 0).unsqueeze(-1)
    d00_large = torch.logical_and(r[0][0] > r[1][1], r[0][0] > r[2][2]).unsqueeze(-1)
    d11_large = (r[1][1] > r[2][2]).unsqueeze(-1)
    q = torch.where(trace_pos, cond0, torch.where(d00_large, cond1, torch.where(d11_large, cond2, cond3)))
    
    xyz, w = torch.split(q, (3, 1), dim=-1)  # 分离旋转轴和角度
    norm = torch.linalg.norm(xyz, dim=-1, keepdim=True)  # 计算旋转轴的模长
    return (torch.nan_to_num(2.0 / norm) * torch.atan2(norm, w)) * xyz  # 返回旋转向量

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

AI算法网奇

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值