nan in Rasterizer gradients #110

Closed
shubhtuls opened this issue Mar 12, 2020 · 8 comments

@shubhtuls

Description

I encountered a case where two vertices of a triangle had the same screen-space XY coordinates, and this led to a 'nan' value in the gradient.

Some interesting things I noted:

  • Using CPU instead of GPU (Quadro GP100 in this case) leads to no issues
  • Using CPU vs GPU also gives different pix2face in forward pass of the rasterizer. The CPU run ignores the offending triangle's contribution, the GPU run does not. I am not sure why this happens, but I suspect this is what needs to be fixed.

Instructions To Reproduce the Issue:

import torch

import pytorch3d
import pytorch3d.renderer
import pytorch3d.structures  # Meshes is used below
import numpy as np
import matplotlib.pyplot as plt

device = torch.device("cuda:0")

# using cpu instead of gpu leads to no nan values!
# device = torch.device("cpu")

# Note that v0 and v2 of the triangle have the same x,y but different z, so it is not degenerate in 3D (just a triangle seen edge-on by the camera)
# These are actual vertex values that occurred during a run
# However, my attempts to manually create simpler vertex locations that reproduce this were unsuccessful
vs = torch.Tensor([[0.7922, -0.1992,  6.8850],[0.8408, -0.1622,  6.8568],[0.7922, -0.1992,  6.89]]).to(device)
fs = torch.Tensor([[0,1,2]]).to(device)

vs.requires_grad = True

meshes = pytorch3d.structures.Meshes([vs],[fs])
cameras = pytorch3d.renderer.OpenGLOrthographicCameras(znear=0, zfar=1, device=device)

blend_params = pytorch3d.renderer.BlendParams(sigma=1e-4, gamma=1e-4)
mask_raster_settings = pytorch3d.renderer.RasterizationSettings(
    image_size=256, 
    blur_radius=np.log(1. / 1e-4 - 1.) * blend_params.sigma, 
    faces_per_pixel=20,
    bin_size=0
)
mask_rasterizer = pytorch3d.renderer.MeshRasterizer(
    cameras=cameras, 
    raster_settings=mask_raster_settings
)
mask_shader = pytorch3d.renderer.SoftSilhouetteShader(blend_params=blend_params)
mask_renderer = pytorch3d.renderer.MeshRenderer(mask_rasterizer, mask_shader)

img_mask = mask_renderer(meshes)

img_mask[0,:,:,3].mean().backward()
print(vs.grad)

plt.imshow(img_mask[0,:,:,3].detach().cpu().numpy())
plt.show()
tensor([[nan, nan, nan],
        [nan, nan, nan],
        [nan, nan, nan]], device='cuda:0')

[Image: output_3_0, the plot produced by plt.imshow above]

## debugging via checking barycentric coords
pix2face, _, barycentric_coords, _ = mask_rasterizer(meshes)
(pix2face == 0).nonzero()[0]
tensor([  0, 147,  17,   0], device='cuda:0')
print(barycentric_coords[0, 147, 17, 0])
tensor([-3.5279e+26,  0.0000e+00,  3.5279e+26], device='cuda:0',
       grad_fn=<SelectBackward>)
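
To illustrate why coordinates of this magnitude can appear, here is a minimal sketch in plain Python of 2D barycentric coordinates computed from signed areas. It is not PyTorch3D's actual kernel; the helper names (`edge_function`, `barycentric_2d`) and the sample points are assumptions chosen for illustration. The point is only that each coordinate is a sub-area divided by the full face area, so a face area close to zero produces very large values.

```python
def edge_function(p, a, b):
    """Twice the signed area of triangle (a, b, p) in the XY plane."""
    return (p[0] - a[0]) * (b[1] - a[1]) - (p[1] - a[1]) * (b[0] - a[0])


def barycentric_2d(p, v0, v1, v2):
    """Return (face_area, (w0, w1, w2)) for point p, using XY only."""
    area = edge_function(v2, v0, v1)
    w0 = edge_function(p, v1, v2) / area
    w1 = edge_function(p, v2, v0) / area
    w2 = edge_function(p, v0, v1) / area
    return area, (w0, w1, w2)


# Hypothetical near-degenerate face: v0 and v2 almost coincide in XY.
v0 = (0.0, 0.0)
v1 = (0.2, 0.2)
v2 = (1e-9, 0.0)
p = (0.1, 0.0)  # a nearby pixel centre, on the line through v0 and v2

area, w = barycentric_2d(p, v0, v1, v2)
print("face area:", area)
print("barycentric coords:", w)
```

In this sketch the area is tiny but nonzero, and the coordinates come out as a large negative value, zero, and a large positive value, which is the same pattern as the tensor printed above. Any gradient that flows through the division by `area` is scaled up accordingly.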
@shubhtuls
Author

Update: actually, even a simpler set of vertex locations reproduces the error, e.g.

vs = torch.Tensor([[0., 0.,  1.0],[0.2, 0.2,  2.0],[0., 0.,  3.0]]).to(device)

@nikhilaravi
Contributor

@shubhtuls thanks for the detailed explanation of the issue. I will try to reproduce the error as described and get back to you!

@nikhilaravi nikhilaravi self-assigned this Mar 12, 2020
@shubhtuls
Author

shubhtuls commented Mar 12, 2020

I think this is related to precision issues when computing the face areas. Adding a statement to print the 'face_area' in the CUDA rasterizer implementation here shows that it is on the order of 1e-9, which is greater than the kEpsilon=1e-30 used to check for zero area, so the check does not treat the face as degenerate.

I unblocked on my end by additionally defining a 'kEpsilonFace=1e-7' in these lines and using that for the zero area check, but I'm not sure if this is the ideal solution.
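
The effect of the two thresholds can be sketched as follows. This is a plain Python stand-in, not the CUDA code: `is_zero_area` is a hypothetical helper, and the exact comparison used in the real kernel (strict or non-strict, absolute value or not) is an assumption here. Only the three numeric values come from this thread.

```python
K_EPSILON = 1e-30      # threshold quoted above for the zero-area check
K_EPSILON_FACE = 1e-7  # larger threshold used in the workaround above


def is_zero_area(face_area, eps):
    """Hypothetical stand-in for the rasterizer's zero-area test."""
    return abs(face_area) <= eps


face_area = 1e-9  # order of magnitude printed from the CUDA kernel

print(is_zero_area(face_area, K_EPSILON))       # False: face is kept
print(is_zero_area(face_area, K_EPSILON_FACE))  # True: face is skipped
```

With the smaller threshold the face passes the check and its tiny area is later used as a divisor; with the larger one the face is treated as having zero area and contributes nothing.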

@gkioxari
Contributor

gkioxari commented Mar 13, 2020

Small face areas are such a headache! I guess in both the example faces above you have almost 0 face areas. Did nans disappear with 1e-7?
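
As a quick check of that guess, the sketch below computes the area of both example triangles twice: once in 3D, and once using XY only. Dropping Z is a simplifying assumption standing in for the orthographic projection; the real pipeline also applies the camera transform in float32, which is presumably where a small nonzero value such as 1e-9 comes from. The function names are hypothetical.

```python
import math


def area_3d(v0, v1, v2):
    """Area of the triangle in 3D: half the norm of the edge cross product."""
    ax, ay, az = (v1[i] - v0[i] for i in range(3))
    bx, by, bz = (v2[i] - v0[i] for i in range(3))
    cx = ay * bz - az * by
    cy = az * bx - ax * bz
    cz = ax * by - ay * bx
    return 0.5 * math.sqrt(cx * cx + cy * cy + cz * cz)


def area_xy(v0, v1, v2):
    """Unsigned area of the triangle after dropping Z."""
    return 0.5 * abs((v1[0] - v0[0]) * (v2[1] - v0[1])
                     - (v1[1] - v0[1]) * (v2[0] - v0[0]))


# The two triangles posted in this thread.
tri_a = [(0.7922, -0.1992, 6.8850), (0.8408, -0.1622, 6.8568), (0.7922, -0.1992, 6.89)]
tri_b = [(0.0, 0.0, 1.0), (0.2, 0.2, 2.0), (0.0, 0.0, 3.0)]

for name, tri in (("tri_a", tri_a), ("tri_b", tri_b)):
    print(name, "3D area:", area_3d(*tri), "XY area:", area_xy(*tri))
```

Both faces have a clearly nonzero 3D area but an XY area of exactly zero in this idealized setting, consistent with a triangle that is valid in 3D yet collapses to a line segment on screen.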

@shubhtuls
Author

@gkioxari - so far, yes!

@nikhilaravi nikhilaravi added the bug Something isn't working label Mar 20, 2020
@tomguluson92

@gkioxari - so far, yes!

Thanks for your solution. Do I understand correctly that it aims to avoid the vanishing-gradient problem by applying a stricter threshold to each face area?

@gkioxari
Contributor

gkioxari commented Apr 8, 2020

@tomguluson92 In a follow-up diff we are setting the kEpsilon value to 1e-8. Yes, the issue is the small face areas, which are detected based on that value.

@nikhilaravi
Contributor

nikhilaravi commented Apr 24, 2020

This has been fixed by 487d4d6.
