UserWarning: R is not a valid rotation matrix (using look_at_rotation) #595

Closed
fairydora opened this issue Mar 11, 2021 · 4 comments
Labels: how to (How to use PyTorch3D in my project)


fairydora commented Mar 11, 2021

Hello, :)

I've taken the tutorial "camera_position_optimization_with_differentiable_rendering" and:

  1. ran it in Colab (the optimisation starts from the position defined by you in the Model: [3.0, 6.9, +2.5]);
  2. ran it in PyCharm (start: [3.0, 6.9, +2.5]);
  3. ran my own project based on it, with slightly different start and target camera settings:
  • start_cam_T = [0.3, 1.0, 4.5]
  • target: (distance=3.8, elevation=-3.0, azimuth=0.0), cam_T = [[-0.0500, -0.1700, 3.8000]]

In each of these tests I get the warning "UserWarning: R is not a valid rotation matrix" at random points during the optimisation.
At first I thought the camera was aligning itself with the up vector, but that doesn't seem to be the case (the default up vector for look_at_rotation is (0, 1, 0)).
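To check the up-vector hypothesis numerically, here is a minimal NumPy sketch of a look-at rotation, assuming the usual construction (z axis from the camera towards the target, x = up × z, y = z × x, stacked as the columns of R). It is not PyTorch3D's code and the helper names are hypothetical, but with up = (0, 1, 0) it reproduces the first R in the Colab log below, and it only breaks down when the viewing direction is (anti)parallel to the up vector.

```python
import numpy as np


def look_at_rotation_np(camera_position, at=(0.0, 0.0, 0.0), up=(0.0, 1.0, 0.0)):
    """Hypothetical NumPy look-at rotation.

    Returns a 3x3 matrix whose columns are the camera x, y and z axes,
    with z pointing from the camera towards `at`.
    """
    z_axis = np.asarray(at, dtype=np.float64) - np.asarray(camera_position, dtype=np.float64)
    z_axis /= np.linalg.norm(z_axis)
    x_axis = np.cross(np.asarray(up, dtype=np.float64), z_axis)
    x_norm = np.linalg.norm(x_axis)
    if x_norm < 1e-6:
        # The viewing direction is (anti)parallel to `up`, so x is undefined.
        raise ValueError("viewing direction is parallel to the up vector")
    x_axis /= x_norm
    y_axis = np.cross(z_axis, x_axis)  # unit length, since z and x are orthonormal
    return np.stack([x_axis, y_axis, z_axis], axis=1)


def up_alignment(camera_position, at=(0.0, 0.0, 0.0), up=(0.0, 1.0, 0.0)):
    """|cos| of the angle between the viewing direction and `up` (1.0 = degenerate)."""
    d = np.asarray(at, dtype=np.float64) - np.asarray(camera_position, dtype=np.float64)
    u = np.asarray(up, dtype=np.float64)
    return float(abs(d @ u) / (np.linalg.norm(d) * np.linalg.norm(u)))


# First and last finite camera positions from the Colab log below
print(np.round(look_at_rotation_np([0.2993, 3.9161, 1.3367]), 4))
print(up_alignment([0.0960, 2.1915, 2.1579]))
```

For the last finite position before the NaNs, |cos| is about 0.71, i.e. the viewing direction is roughly 45° away from the up axis, which supports the observation that a look-at degeneracy is not the cause here.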

I'm using the following versions:
- pytorch 1.8.0
- pytorch3d 0.4.0
- python 3.6.8

I would like to know how I can work around this, or whether I am doing something wrong. This issue is rather urgent for me.

To reproduce the problem in Colab, run the notebook several times starting from "3. Initialize the model and optimizer".
It always runs correctly the first time. I can also send the version I copied to PyCharm.

These are the last few printouts from your Colab version of the tutorial.
In this case the problem happens on iteration 90, but it is random and can happen almost immediately:

```
model.camera_position: tensor([0.2993, 3.9161, 1.3367], device='cuda:0')
R: tensor([[[-0.9758, -0.2063, -0.0721],
         [ 0.0000, 0.3302, -0.9439],
         [ 0.2185, -0.9211, -0.3222]]], device='cuda:0')
T: tensor([[-0.0000, -0.0000, 4.1488]], device='cuda:0')

model.camera_position: tensor([-0.2256, 3.2020, 1.6944], device='cuda:0')
R: tensor([[[-0.9913, 0.1164, 0.0621],
         [ 0.0000, 0.4709, -0.8822],
         [-0.1320, -0.8745, -0.4668]]], device='cuda:0')
T: tensor([[1.4901e-08, -0.0000e+00, 3.6297e+00]], device='cuda:0')

model.camera_position: tensor([-0.0372, 2.4358, 1.8689], device='cuda:0')
R: tensor([[[-0.9998, 0.0158, 0.0121],
         [ 0.0000, 0.6088, -0.7933],
         [-0.0199, -0.7932, -0.6087]]], device='cuda:0')
T: tensor([[-0.0000e+00, -2.3842e-07, 3.0704e+00]], device='cuda:0')

model.camera_position: tensor([0.0960, 2.1915, 2.1579], device='cuda:0')
R: tensor([[[-0.9990, -0.0316, -0.0312],
         [ 0.0000, 0.7020, -0.7122],
         [ 0.0444, -0.7115, -0.7013]]], device='cuda:0')
T: tensor([[-0.0000e+00, -1.1921e-07, 3.0771e+00]], device='cuda:0')
/usr/local/lib/python3.7/dist-packages/pytorch3d/transforms/transform3d.py:726: UserWarning: R is not a valid rotation matrix
  warnings.warn(msg)

model.camera_position: tensor([nan, nan, nan], device='cuda:0')
R: tensor([[[nan, nan, nan],
         [nan, nan, nan],
         [nan, nan, nan]]], device='cuda:0')
T: tensor([[nan, nan, nan]], device='cuda:0')
```
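The printed T can be cross-checked against R and the camera position. Assuming the tutorial's model builds the translation as T = -Rᵀ C, with C being model.camera_position (an assumption, but one consistent with every printout above), T_x and T_y vanish and T_z is simply the camera's distance from the origin. Because R and T are both derived from C, a NaN in model.camera_position propagates to both, which is what the final printout shows. A small NumPy sketch using the first logged step:

```python
import numpy as np

# First logged step from the Colab printout above
C = np.array([0.2993, 3.9161, 1.3367])          # model.camera_position
R = np.array([[-0.9758, -0.2063, -0.0721],
              [ 0.0000,  0.3302, -0.9439],
              [ 0.2185, -0.9211, -0.3222]])

# Assumed relation used to build the world-to-view translation: T = -R^T C
T = -R.T @ C

print(np.round(T, 3))                      # x and y are ~0, z is the camera distance
print(round(float(np.linalg.norm(C)), 4))  # distance of the camera from the origin
```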

This is the output from my own optimisation; I can reproduce the problem almost every time I run it:

```
pos: tensor([-0.5910, -0.5706, 2.9129], device='cuda:0')
T: tensor([[-0.0000, -0.0000, 3.0265]], device='cuda:0')
R: tensor([[[-0.9800, -0.0375, 0.1953],
         [ 0.0000, 0.9821, 0.1885],
         [-0.1988, 0.1848, -0.9625]]], device='cuda:0')
:: loss: tensor(2854.3237, device='cuda:0')

pos: tensor([-0.5813, -0.5776, 2.9002], device='cuda:0')
T: tensor([[-0.0000e+00, -5.9605e-08, 3.0137e+00]], device='cuda:0')
R: tensor([[[-0.9805, -0.0377, 0.1929],
         [ 0.0000, 0.9815, 0.1916],
         [-0.1965, 0.1879, -0.9623]]], device='cuda:0')
:: loss: tensor(2838.6309, device='cuda:0')

pos: tensor([nan, nan, nan], device='cuda:0')
T: tensor([[nan, nan, nan]], device='cuda:0')
R: tensor([[[nan, nan, nan],
         [nan, nan, nan],
         [nan, nan, nan]]], device='cuda:0')
C:\anaconda\envs\bhava4\lib\site-packages\pytorch3d\transforms\transform3d.py:726: UserWarning: R is not a valid rotation matrix
  warnings.warn(msg)
:: loss: tensor(10243.0742, device='cuda:0')
```
[Screenshot: R_not_valid_rotation_matrix]

nikhilaravi (Contributor) commented

@fairydora this is where the warning is coming from:

```python
import warnings

import torch


def _check_valid_rotation_matrix(R, tol: float = 1e-7):
    """
    Determine if R is a valid rotation matrix by checking it satisfies the
    following conditions:

    ``RR^T = I and det(R) = 1``

    Args:
        R: an (N, 3, 3) matrix

    Returns:
        None

    Emits a warning if R is an invalid rotation matrix.
    """
    N = R.shape[0]
    eye = torch.eye(3, dtype=R.dtype, device=R.device)
    eye = eye.view(1, 3, 3).expand(N, -1, -1)
    # Orthogonality: R R^T must equal the identity for every matrix in the batch
    orthogonal = torch.allclose(R.bmm(R.transpose(1, 2)), eye, atol=tol)
    # Proper rotation: det(R) must be +1 (no reflection or scaling)
    det_R = torch.det(R)
    no_distortion = torch.allclose(det_R, torch.ones_like(det_R))
    if not (orthogonal and no_distortion):
        msg = "R is not a valid rotation matrix"
        warnings.warn(msg)
    return
```
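To see how this check reacts to the matrices in the logs, here is a NumPy port of the same two conditions (a sketch, not PyTorch3D code; `is_valid_rotation` is a hypothetical name). Like torch.allclose, np.allclose uses a default rtol of 1e-5 and treats NaN as unequal by default, so an all-NaN R, such as the one produced after the failing step, always fails. Note also that for the off-diagonal entries of R Rᵀ the target is 0, so only the tight absolute tolerance (1e-7) applies there.

```python
import numpy as np


def is_valid_rotation(R, tol=1e-7):
    """NumPy port of the two conditions checked above: R R^T = I and det(R) = 1.

    R has shape (N, 3, 3). np.allclose returns False whenever a NaN is present.
    """
    R = np.asarray(R, dtype=np.float64)
    eye = np.broadcast_to(np.eye(3), R.shape)
    orthogonal = np.allclose(R @ np.swapaxes(R, 1, 2), eye, atol=tol)
    det_R = np.linalg.det(R)
    no_distortion = np.allclose(det_R, np.ones_like(det_R))
    return bool(orthogonal and no_distortion)


theta = 0.3
rot_y = np.array([[[np.cos(theta), 0.0, np.sin(theta)],
                   [0.0, 1.0, 0.0],
                   [-np.sin(theta), 0.0, np.cos(theta)]]])
nan_R = np.full((1, 3, 3), np.nan)

print(is_valid_rotation(rot_y))  # a proper rotation passes
print(is_valid_rotation(nan_R))  # the all-NaN R from the logs fails
```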

Is the warning affecting the result? Does the optimization complete or are you getting NaNs? Can you also check #585 to see if you have the same issue?

nikhilaravi self-assigned this Mar 11, 2021
nikhilaravi added the "how to" label (How to use PyTorch3D in my project) Mar 11, 2021

fairydora commented Mar 11, 2021

Hello Nikhila :)
I get NaNs for the camera position and the rotation matrix. I posted my printouts above: you can see the tensor values just before this happened and at the moment it happened. You will also get NaNs in the Colab tutorial if you run it a few times (I've included those printouts above as well).

The result is a completely white render, so I can't finish the optimisation.
I would really appreciate any help; it's very urgent for me to get this working, and the failure is completely unpredictable.
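Independently of the root cause, a common defensive pattern is to stop one bad step from destroying the optimisation state: check that the gradient and the updated parameters are finite and, if not, keep the last good parameters. The sketch below shows the idea with plain gradient descent in NumPy; `guarded_step`, the toy quadratic loss and its target are hypothetical. In the PyTorch tutorial a comparable check could use `torch.isfinite(...).all()` on `model.camera_position` and its gradient before calling `optimizer.step()`.

```python
import numpy as np


def guarded_step(params, grad, lr):
    """Hypothetical guarded gradient-descent step.

    Returns (new_params, accepted). If the gradient or the updated parameters
    contain NaN/inf, the step is rejected and the old parameters are returned.
    """
    if not np.all(np.isfinite(grad)):
        return params, False
    new_params = params - lr * grad
    if not np.all(np.isfinite(new_params)):
        return params, False
    return new_params, True


# Toy objective: squared distance to a (made-up) target camera position
target = np.array([0.0, 0.0, 3.8])
params = np.array([0.3, 1.0, 4.5])

for step in range(50):
    grad = 2.0 * (params - target)
    if step == 10:
        grad = np.array([np.nan, np.nan, np.nan])  # simulate one bad backward pass
    params, accepted = guarded_step(params, grad, lr=0.1)
    if not accepted:
        print(f"step {step}: non-finite update skipped")

print(np.round(params, 4))
```

This only contains the damage; if NaNs appear on most steps, the underlying cause still has to be fixed.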

I'll immediately test the solution from #585 and report the outcome, thanks!


fairydora commented Mar 11, 2021

Hi Nikhila,
the solution from #585 (adding perspective_correct=False to the RasterizationSettings) seems to have solved the issue for me :) Thank you very much!
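The thread does not say why perspective_correct=False helps. One plausible mechanism, offered as an assumption rather than a diagnosis: perspective-correct interpolation rescales the screen-space barycentrics (w0, w1, w2) by the vertex depths (z0, z1, z2) and renormalises, w_i' = (w_i / z_i) / Σ_j (w_j / z_j). With a nonzero blur radius, pixels just outside a triangle can have negative barycentrics, so the denominator can reach zero and produce inf/NaN, which would then backpropagate into the camera parameters. The NumPy sketch below (function name and eps clamp are illustrative, not PyTorch3D's kernel) shows both the blow-up and a clamped variant.

```python
import numpy as np


def perspective_correct_bary(bary, z, eps=None):
    """Perspective-correct barycentrics: (w_i / z_i) / sum_j (w_j / z_j).

    Written in the multiply-through form w_i * z_j * z_k, which avoids dividing
    by the individual depths. If eps is given, the denominator is clamped to at
    least eps (an illustrative guard, assumed here).
    """
    w0, w1, w2 = bary
    z0, z1, z2 = z
    top = np.array([w0 * z1 * z2, w1 * z0 * z2, w2 * z0 * z1], dtype=np.float64)
    denom = top.sum()
    if eps is not None:
        denom = max(denom, eps)
    with np.errstate(divide="ignore", invalid="ignore"):
        return top / denom


inside = perspective_correct_bary((0.2, 0.3, 0.5), (2.0, 3.0, 4.0))
outside = perspective_correct_bary((-1.0, 2.0, 0.0), (1.0, 2.0, 3.0))
clamped = perspective_correct_bary((-1.0, 2.0, 0.0), (1.0, 2.0, 3.0), eps=1e-8)

print(inside, inside.sum())  # finite, sums to 1
print(outside)               # contains inf/NaN: the denominator is exactly 0
print(clamped)               # finite (but huge) once the denominator is clamped
```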

nikhilaravi (Contributor) commented

Great! We'll fix the issue in master!
