NaN when using MeshRasterizer #561

Open
pengsongyou opened this issue Feb 12, 2021 · 32 comments
Assignees
Labels
do-not-reap Do not delete this pull request or issue due to inactivity. potential-bug Potential bug to flag an issue that needs to be looked into

Comments

@pengsongyou

Description

I installed the latest PyTorch3D (0.4) and ran the fit_textured_mesh tutorial, specifically the "Mesh prediction via silhouette rendering" section. The loss becomes NaN after around 200 iterations (I can reproduce this issue 4 out of 5 times).

I also tried PyTorch3D 0.3 (built from source in December), and this issue never occurred. Therefore, the problem may have been introduced by the latest update to the mesh rasterizer.

Reproduce

Install PyTorch 1.7.1:

conda install pytorch torchvision torchaudio cudatoolkit=10.2 -c pytorch

Install PyTorch3D following the "wheels for Linux" instructions:

pip install pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/py38_cu102_pyt171/download.html

Then simply run the fit_textured_mesh tutorial and you should be able to reproduce the problem. I get NaN in 4 out of 5 runs.

Best,
Songyou

@nikhilaravi
Contributor

Thanks @pengsongyou for reporting this issue! We'll look into it asap.

@nikhilaravi nikhilaravi added the potential-bug Potential bug to flag an issue that needs to be looked into label Feb 12, 2021
@nikhilaravi nikhilaravi self-assigned this Feb 12, 2021
@nikhilaravi
Contributor

@pengsongyou I was able to reproduce the error. To resolve the issue in the tutorial, add perspective_correct=False to the RasterizationSettings for the rasterizer. In v0.4 we changed this setting to be automatically inferred from the camera type, but there seems to be some instability due to this. We will debug what is happening!
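
For reference, a minimal sketch of the workaround (the name cameras and the image size / blur values here are illustrative, not the tutorial's exact settings):

from pytorch3d.renderer import RasterizationSettings, MeshRasterizer

# Keep the tutorial's own settings; the only change needed is the extra flag.
raster_settings = RasterizationSettings(
    image_size=128,
    blur_radius=0.0,
    faces_per_pixel=50,
    perspective_correct=False,  # workaround for the NaN instability
)
rasterizer = MeshRasterizer(cameras=cameras, raster_settings=raster_settings)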

@pengsongyou
Author

pengsongyou commented Feb 12, 2021

Great, it does indeed seem to work now, thanks a lot! I have always been using the perspective camera model, but I did not need to set perspective_correct=False in 0.3 because no issue ever appeared. Just wondering if you could explain why we now need to make it explicitly False in 0.4?

Thanks so much in advance!

Best,
Songyou

@nikhilaravi
Contributor

@pengsongyou the perspective_correct setting basically ensures that the barycentric coordinates are correct under a perspective camera. This is not corrected in other differentiable renderers like SoftRas/NMR/DIB-R, which assume that the perspective effects are small. In the previous version of PyTorch3D this was an optional setting, but in the most recent release we decided to set it based on the type of the camera. We will investigate why this is causing NaNs in the optimization.
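
For intuition, the correction re-weights the screen-space barycentric coordinates by the inverse vertex depths and renormalizes; a rough NumPy sketch of the idea (illustrative only, not the actual kernel in csrc/utils/geometry_utils.cuh):

import numpy as np

def perspective_correct_barycentric(w, z):
    # w: screen-space barycentric weights (w0, w1, w2) at a pixel
    # z: camera-space depths (z0, z1, z2) of the face's three vertices
    l = np.asarray(w, dtype=np.float64) / np.asarray(z, dtype=np.float64)
    denom = l.sum()  # can approach zero for near edge-on faces, producing NaNs
    return l / denom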

@JudyYe
Contributor

JudyYe commented Apr 18, 2021

Hi, I have encountered a similar NaN error in the rasterizer :/. I just want to provide another example that might help the team debug. For now, perspective_correct=False / an orthographic camera solves this particular case (thanks Nikhila and Georgia).

NaN seems to happen when the rendered face is parallel to the ray (maybe relevant to the previous issue #110).
I have provided the triangle that causes NaN fragments in the file triangle.pkl, together with my script:

    import pickle

    import numpy as np
    from pytorch3d.renderer import (
        BlendParams,
        MeshRasterizer,
        PerspectiveCameras,
        RasterizationSettings,
    )

    fname = 'triangle.pkl'
    device = 'cuda:0'
    with open(fname, 'rb') as fp:
        obj = pickle.load(fp)
        triangle = obj['tri']  # a Meshes object holding the single triangle
        triangle = triangle.to(device)

    cameras = PerspectiveCameras(focal_length=100., device=device)
    blend_params = BlendParams(sigma=1e-4, gamma=1e-4)
    dist_eps = 1e-6
    raster_settings = RasterizationSettings(
        image_size=224,
        blur_radius=np.log(1. / dist_eps - 1.) * blend_params.sigma,
        faces_per_pixel=100,
        # perspective_correct=False,  # this seems to solve the NaN error, at least for this case
    )
    rasterizer = MeshRasterizer(cameras=cameras, raster_settings=raster_settings).to(device)
    fragments = rasterizer(triangle)
    print(fragments.zbuf.isnan().any(), fragments.bary_coords.isnan().any())
    # True, True for me

The triangle looks like this in 3D: [3D plot image]
and this in screen space: [screen-space plot image]
Visualization code:

from mpl_toolkits import mplot3d  # noqa: F401 (needed for the 3D projection)
import numpy as np
import matplotlib.pyplot as plt
import pickle

fname = '/tmp/transfer/vis/triangle.pkl'
with open(fname, 'rb') as fp:
    triangle = pickle.load(fp)
verts = triangle['verts']
verts2d = triangle['verts_screen']

def refract_verts(verts):
    # Repeat the first vertex at the end so the plotted triangle is closed.
    verts = np.vstack([verts, verts[0:1]])
    return verts

verts = refract_verts(verts)
verts2d = refract_verts(verts2d)

fig = plt.figure()
ax = plt.axes(projection='3d')

ax.set_xlabel('X axis')
ax.set_ylabel('Y axis')
ax.set_zlabel('Z axis')

ax.plot3D(verts[:, 0], verts[:, 1], verts[:, 2], 'gray')

fig = plt.figure()
plt.plot(verts2d[:, 0], verts2d[:, 1])
plt.show()

Thanks and good luck.

@github-actions

This issue is stale because it has been open 30 days with no activity. Remove stale label or comment or this will be closed in 5 days.

@github-actions github-actions bot added the Stale label Jun 23, 2021
@github-actions

This issue was closed because it has been stalled for 5 days with no activity.

@nikhilaravi nikhilaravi reopened this Jun 29, 2021
@nikhilaravi
Contributor

I will look into this issue! Thanks for the explanation @JudyYe.

@tals

tals commented Jul 23, 2021

Hey, I am experiencing the same issue (NaNs after about 200 iteration steps).
perspective_correct=False doesn't seem to help though :(

EDIT: I didn't notice there were multiple RasterizationSettings instances. It works now!

$ pip list | grep torch
pytorch3d                         0.4.0
torch                             1.7.1+cu110

@jbohnslav

I can confirm both that this bug still exists in 0.5.0 and that setting perspective_correct=False removes the issue. I ran my code with torch anomaly detection on; not sure if it's helpful. Here's the relevant portion of the anomaly detection output:

  File "/home/jim/Documents/python/pytorch3d/pytorch3d/renderer/mesh/renderer.py", line 59, in forward
    fragments = self.rasterizer(meshes_world, **kwargs)
  File "/home/jim/anaconda3/envs/armo/lib/python3.8/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/jim/Documents/python/pytorch3d/pytorch3d/renderer/mesh/rasterizer.py", line 171, in forward
    pix_to_face, zbuf, bary_coords, dists = rasterize_meshes(
  File "/home/jim/Documents/python/pytorch3d/pytorch3d/renderer/mesh/rasterize_meshes.py", line 231, in rasterize_meshes
    pix_to_face, zbuf, barycentric_coords, dists = _RasterizeFaceVerts.apply(
 (function _print_stack)

And the error message:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
/tmp/ipykernel_89528/4272499749.py in <module>
     80     optimizer.zero_grad()
---> 81     loss.backward()
     82     optimizer.step()

~/anaconda3/envs/armo/lib/python3.8/site-packages/torch/tensor.py in backward(self, gradient, retain_graph, create_graph)
    219                 retain_graph=retain_graph,
    220                 create_graph=create_graph)
--> 221         torch.autograd.backward(self, gradient, retain_graph, create_graph)
    222 
    223     def register_hook(self, hook):

~/anaconda3/envs/armo/lib/python3.8/site-packages/torch/autograd/__init__.py in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables)
    128         retain_graph = create_graph
    129 
--> 130     Variable._execution_engine.run_backward(
    131         tensors, grad_tensors_, retain_graph, create_graph,
    132         allow_unreachable=True)  # allow_unreachable flag

RuntimeError: Function '_RasterizeFaceVertsBackward' returned nan values in its 0th output.
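
For anyone who wants to reproduce this kind of trace, anomaly detection is enabled with a single call before the training loop (standard PyTorch API):

import torch

# Records forward-pass stack traces so the backward pass can report which op
# (here _RasterizeFaceVertsBackward) first produced NaN values.
torch.autograd.set_detect_anomaly(True)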

@dukleryoni

Hi, same here: perspective_correct=False fixes the issue. In my setting, I also find that if the mesh and renderer are on the CPU, I no longer get NaNs (even when perspective_correct=True).

Additionally, I was wondering whether using a FoVPerspectiveCameras camera with perspective_correct=False in the rasterization settings is equivalent to using a weak-perspective camera?

@rubenverhack

rubenverhack commented Sep 7, 2021

I can confirm that this bug is present in v0.5.0 using the out-of-the-box tutorial "camera_position_optimization_with_differentiable_rendering". perspective_correct=False fixes the issue.

facebook-github-bot pushed a commit that referenced this issue Oct 22, 2021
Summary:
#561
#790
Divide by zero fix (NaN fix).  When perspective_correct=True, BarycentricPerspectiveCorrectionForward and BarycentricPerspectiveCorrectionBackward in ../csrc/utils/geometry_utils.cuh are called.  The denominator (denom) values should not be allowed to go to zero. I'm able to resolve this issue locally with this PR and submit it for the team's review.

Pull Request resolved: #891

Reviewed By: patricklabatut

Differential Revision: D31829695

Pulled By: bottler

fbshipit-source-id: a3517b8362f6e60d48c35731258d8ce261b1d912
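
In essence, the fix keeps the barycentric-correction denominator away from zero. A hedged Python sketch of the idea (the actual change lives in the CUDA/C++ kernels, and the epsilon used there may differ):

import torch

def clamp_denominator(denom: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    # Push values in (-eps, eps) out to +/-eps (exact zeros go to +eps) so the
    # subsequent division cannot produce inf/NaN values or gradients.
    sign = torch.where(denom >= 0, torch.ones_like(denom), -torch.ones_like(denom))
    return torch.where(denom.abs() < eps, eps * sign, denom)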
@TimmmYang

Hello, the NaN problem still exists. In my case, I use RasterizationSettings as follows:

raster_settings = RasterizationSettings(
    image_size=(self.img_h, self.img_w),
    blur_radius=0,
    faces_per_pixel=1,
    perspective_correct=False,
)

My environment:

# Name                    Version                   Build  Channel
pytorch                   1.8.1           py3.7_cuda11.1_cudnn8.0.5_0    pytorch
pytorch3d                 0.6.1                     dev_0    <develop>

@github-actions

github-actions bot commented May 1, 2022

This issue is stale because it has been open 30 days with no activity. Remove stale label or comment or this will be closed in 5 days.

@github-actions github-actions bot added the Stale label May 1, 2022
@github-actions

github-actions bot commented May 7, 2022

This issue was closed because it has been stalled for 5 days with no activity.

@github-actions github-actions bot closed this as completed May 7, 2022
@bottler bottler removed the Stale label May 9, 2022
@bottler bottler reopened this May 9, 2022
@d4l3k
Member

d4l3k commented Jun 20, 2022

I'm also running into periodic NaNs with the mesh rasterizer. It seems to occur with the HardDepthShader in #1208, which is about as simple as you can get shading-wise.

@d4l3k
Member

d4l3k commented Jun 22, 2022

I turned on anomaly detection and traced those NaNs back to the transform_points denom correction in

https://github.com/facebookresearch/pytorch3d/blob/main/pytorch3d/renderer/mesh/rasterizer.py#L193-L196

It might be a good idea to change eps so it's not None. It is set in a lot of places, so None seems like a bad default given the potential for bad behavior. I set it to eps=1e-8 and that seems to have solved it. Implicitron looks like it sets it to 1e-2, which seems very large.
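
A minimal example of the workaround, assuming renderer is a MeshRenderer (or a bare MeshRasterizer) and meshes is the batch being rendered; the eps kwarg is forwarded down to the transform_points call linked above:

# 1e-8 is the value suggested above, not an official default.
images = renderer(meshes, eps=1e-8)

# The same kwarg can also be passed to a MeshRasterizer directly:
fragments = rasterizer(meshes, eps=1e-8)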

@srph25

srph25 commented Apr 21, 2023

I encountered a similar problem and I think d4l3k is right. Calling renderer(meshes, eps=1e-8) or similarly for point clouds solved the issue for me.

@relh
Contributor

relh commented Sep 20, 2023

I encountered a similar problem and I think d4l3k is right. Calling renderer(meshes, eps=1e-8) or similarly for point clouds solved the issue for me.

This solved my problem too! Thanks so much, as setting perspective_correct=False didn't do it.
