Skip to content

Commit

Permalink
remove requires_grad from random rotations
Browse files Browse the repository at this point in the history
Summary: Because rotations and (rotation) quaternions live on curved manifolds, it doesn't make sense to optimize them directly. Having a prominent option to require gradient on random ones may cause people to try, and isn't particularly useful.

Reviewed By: theschnitz

Differential Revision: D29160734

fbshipit-source-id: fc9e320672349fe334747c5b214655882a460a62
  • Loading branch information
bottler authored and facebook-github-bot committed Jun 21, 2021
1 parent 31c448a commit ce60d4b
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 23 deletions.
26 changes: 6 additions & 20 deletions pytorch3d/transforms/rotation_conversions.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,9 +282,7 @@ def matrix_to_euler_angles(matrix, convention: str):
return torch.stack(o, -1)


def random_quaternions(
n: int, dtype: Optional[torch.dtype] = None, device=None, requires_grad=False
):
def random_quaternions(n: int, dtype: Optional[torch.dtype] = None, device=None):
"""
Generate random quaternions representing rotations,
i.e. versors with nonnegative real part.
Expand All @@ -294,21 +292,17 @@ def random_quaternions(
dtype: Type to return.
device: Desired device of returned tensor. Default:
uses the current device for the default tensor type.
requires_grad: Whether the resulting tensor should have the gradient
flag set.
Returns:
Quaternions as tensor of shape (N, 4).
"""
o = torch.randn((n, 4), dtype=dtype, device=device, requires_grad=requires_grad)
o = torch.randn((n, 4), dtype=dtype, device=device)
s = (o * o).sum(1)
o = o / _copysign(torch.sqrt(s), o[:, 0])[:, None]
return o


def random_rotations(
n: int, dtype: Optional[torch.dtype] = None, device=None, requires_grad=False
):
def random_rotations(n: int, dtype: Optional[torch.dtype] = None, device=None):
"""
Generate random rotations as 3x3 rotation matrices.
Expand All @@ -317,35 +311,27 @@ def random_rotations(
dtype: Type to return.
device: Device of returned tensor. Default: if None,
uses the current device for the default tensor type.
requires_grad: Whether the resulting tensor should have the gradient
flag set.
Returns:
Rotation matrices as tensor of shape (n, 3, 3).
"""
quaternions = random_quaternions(
n, dtype=dtype, device=device, requires_grad=requires_grad
)
quaternions = random_quaternions(n, dtype=dtype, device=device)
return quaternion_to_matrix(quaternions)


def random_rotation(
dtype: Optional[torch.dtype] = None, device=None, requires_grad=False
):
def random_rotation(dtype: Optional[torch.dtype] = None, device=None):
"""
Generate a single random 3x3 rotation matrix.
Args:
dtype: Type to return
device: Device of returned tensor. Default: if None,
uses the current device for the default tensor type
requires_grad: Whether the resulting tensor should have the gradient
flag set
Returns:
Rotation matrix as tensor of shape (3, 3).
"""
return random_rotations(1, dtype, device, requires_grad)[0]
return random_rotations(1, dtype, device)[0]


def standardize_quaternion(quaternions):
Expand Down
9 changes: 6 additions & 3 deletions tests/test_rotation_conversions.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,8 @@ def test_to_quat(self):

def test_quat_grad_exists(self):
"""Quaternion calculations are differentiable."""
rotation = random_rotation(requires_grad=True)
rotation = random_rotation()
rotation.requires_grad = True
modified = quaternion_to_matrix(matrix_to_quaternion(rotation))
[g] = torch.autograd.grad(modified.sum(), rotation)
self.assertTrue(torch.isfinite(g).all())
Expand Down Expand Up @@ -131,7 +132,8 @@ def test_to_euler(self):

def test_euler_grad_exists(self):
"""Euler angle calculations are differentiable."""
rotation = random_rotation(dtype=torch.float64, requires_grad=True)
rotation = random_rotation(dtype=torch.float64)
rotation.requires_grad = True
for convention in self._all_euler_angle_conventions():
euler_angles = matrix_to_euler_angles(rotation, convention)
mdata = euler_angles_to_matrix(euler_angles, convention)
Expand Down Expand Up @@ -218,7 +220,8 @@ def test_to_axis_angle(self):

def test_quaternion_application(self):
"""Applying a quaternion is the same as applying the matrix."""
quaternions = random_quaternions(3, torch.float64, requires_grad=True)
quaternions = random_quaternions(3, torch.float64)
quaternions.requires_grad = True
matrices = quaternion_to_matrix(quaternions)
points = torch.randn(3, 3, dtype=torch.float64, requires_grad=True)
transform1 = quaternion_apply(quaternions, points)
Expand Down

0 comments on commit ce60d4b

Please sign in to comment.