Open
Labels: bug (Something isn't working)
Description
I'm looking at pytorch3d's source code and this line looks a bit worrying:
`axis_angles[near_pi] = angles[near_pi] * n / torch.norm(n)`
I've attached the exact code snippet below:
```python
def matrix_to_axis_angle(matrix: torch.Tensor, fast: bool = False) -> torch.Tensor:
    """
    Convert rotations given as rotation matrices to axis/angle.

    Args:
        matrix: Rotation matrices as tensor of shape (..., 3, 3).
        fast: Whether to use the new faster implementation (based on the
            Rodrigues formula) instead of the original implementation (which
            first converted to a quaternion and then back to a rotation matrix).

    Returns:
        Rotations given as a vector in axis angle form, as a tensor
            of shape (..., 3), where the magnitude is the angle
            turned anticlockwise in radians around the vector's
            direction.
    """
    if not fast:
        return quaternion_to_axis_angle(matrix_to_quaternion(matrix))

    if matrix.size(-1) != 3 or matrix.size(-2) != 3:
        raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")

    omegas = torch.stack(
        [
            matrix[..., 2, 1] - matrix[..., 1, 2],
            matrix[..., 0, 2] - matrix[..., 2, 0],
            matrix[..., 1, 0] - matrix[..., 0, 1],
        ],
        dim=-1,
    )
    norms = torch.norm(omegas, p=2, dim=-1, keepdim=True)
    traces = torch.diagonal(matrix, dim1=-2, dim2=-1).sum(-1).unsqueeze(-1)
    angles = torch.atan2(norms, traces - 1)

    zeros = torch.zeros(3, dtype=matrix.dtype, device=matrix.device)
    omegas = torch.where(torch.isclose(angles, torch.zeros_like(angles)), zeros, omegas)

    near_pi = angles.isclose(angles.new_full((1,), torch.pi)).squeeze(-1)

    axis_angles = torch.empty_like(omegas)
    axis_angles[~near_pi] = (
        0.5 * omegas[~near_pi] / torch.sinc(angles[~near_pi] / torch.pi)
    )

    # this derives from: nnT = (R + 1) / 2
    n = 0.5 * (
        matrix[near_pi][..., 0, :]
        + torch.eye(1, 3, dtype=matrix.dtype, device=matrix.device)
    )
    axis_angles[near_pi] = angles[near_pi] * n / torch.norm(n)

    return axis_angles
```

The idea is that the fast path uses the Rodrigues formula to compute axis/angle vectors from rotation matrices. In the near-π branch, however, `n` has shape `(K, 3)`, where `K` is the number of matrices in the batch whose angle is close to π, while `torch.norm(n)` is called without a `dim` argument and therefore returns a single norm over all `K` rows. Whenever more than one matrix is near π, each axis is divided by that shared value instead of its own length, so the axes still point along ±n but their lengths are no longer the rotation angle. I suppose the right fix is to normalize along the last dimension instead, e.g. `torch.linalg.vector_norm(n, dim=-1, keepdim=True)` (with `keepdim=True` so the `(K, 1)` result broadcasts against `n`)?
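Below is a minimal reproduction sketch of the near-π branch. It uses a standalone copy of that branch rather than calling PyTorch3D, and `rotation_by_pi` and `near_pi_axis_angle` are hypothetical helpers written for illustration. It also assumes every matrix is a rotation by exactly π, so it uses `math.pi` where the library uses `angles[near_pi]`. It builds two rotations by π about different axes and compares the current `torch.norm(n)` with a per-row `torch.linalg.vector_norm(..., dim=-1, keepdim=True)`.

```python
import math

import torch


def rotation_by_pi(axis: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: R = 2 n n^T - I, the rotation by pi about `axis`."""
    n = axis / torch.linalg.vector_norm(axis)
    return 2.0 * torch.outer(n, n) - torch.eye(3, dtype=axis.dtype)


def near_pi_axis_angle(matrices: torch.Tensor, per_row: bool) -> torch.Tensor:
    """Hypothetical copy of the near-pi branch for a (K, 3, 3) batch.

    Assumes every matrix is a rotation by exactly pi, so the angle is math.pi.
    """
    # First row of (R + I) / 2, which equals n[0] * n; shape (K, 3).
    n = 0.5 * (matrices[..., 0, :] + torch.eye(1, 3, dtype=matrices.dtype))
    if per_row:
        # Proposed variant: one norm per matrix, shape (K, 1), broadcasts against n.
        norm = torch.linalg.vector_norm(n, dim=-1, keepdim=True)
    else:
        # Current code: a single scalar norm over all K * 3 entries.
        norm = torch.norm(n)
    return math.pi * n / norm


axes = torch.tensor([[1.0, 0.0, 0.0], [1.0, 1.0, 0.0]], dtype=torch.float64)
matrices = torch.stack([rotation_by_pi(a) for a in axes])  # shape (2, 3, 3)

current = near_pi_axis_angle(matrices, per_row=False)
patched = near_pi_axis_angle(matrices, per_row=True)

# Each row's length should equal the rotation angle, pi.
print(torch.linalg.vector_norm(current, dim=-1))  # approx. [2.5651, 1.8138]
print(torch.linalg.vector_norm(patched, dim=-1))  # approx. [3.1416, 3.1416]
```

With a single near-π matrix the two variants agree, because the global norm is then that row's own norm; the difference only shows up once two or more matrices in the batch take the near-π branch.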
I might be wrong, as I'm not super familiar with the Rodrigues formula, but I'd be happy to hear whether there's actually a problem.
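For reference, here is how I understand the two branches, assuming the standard Rodrigues form R = I + sin θ K + (1 - cos θ) K², where K is the cross-product matrix of the unit axis n and `torch.sinc` is the normalized sinc, sin(πx)/(πx). If this is right, the near-π formula itself is sound and only the normalization needs a per-matrix norm. The first-row choice also assumes n₀ ≠ 0: for an axis perpendicular to x, that row of (R + I)/2 is zero.

```math
\begin{aligned}
R &= I + \sin\theta\,K + (1-\cos\theta)\,K^2, \qquad K = [n]_\times, \qquad K^2 = nn^\top - I \\
R - R^\top &= 2\sin\theta\,K \;\Rightarrow\; \omega = 2\sin\theta\,n, \qquad \operatorname{tr}R - 1 = 2\cos\theta, \qquad \theta = \operatorname{atan2}\!\left(\lVert\omega\rVert,\ \operatorname{tr}R - 1\right) \\
\theta\,n &= \frac{\omega}{2\,\sin\theta/\theta} = \frac{0.5\,\omega}{\operatorname{sinc}(\theta/\pi)}, \qquad \operatorname{sinc}(x) = \frac{\sin(\pi x)}{\pi x} \\
\theta = \pi:\quad R &= I + 2\,(nn^\top - I) = 2\,nn^\top - I \;\Rightarrow\; nn^\top = \tfrac{1}{2}\,(R + I) \\
\tfrac{1}{2}\,(R_{0,:} + e_0) &= n_0\,n \;\Rightarrow\; \pm n = \frac{R_{0,:} + e_0}{\lVert R_{0,:} + e_0 \rVert} \quad \text{(one norm per matrix, requires } n_0 \neq 0\text{)}
\end{aligned}
```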