diff --git a/pytorch3d/csrc/utils/geometry_utils.cuh b/pytorch3d/csrc/utils/geometry_utils.cuh index 53ff4f5a0..9e2979aca 100644 --- a/pytorch3d/csrc/utils/geometry_utils.cuh +++ b/pytorch3d/csrc/utils/geometry_utils.cuh @@ -177,7 +177,7 @@ __device__ inline float3 BarycentricPerspectiveCorrectionForward( const float w0_top = bary.x * z1 * z2; const float w1_top = z0 * bary.y * z2; const float w2_top = z0 * z1 * bary.z; - const float denom = w0_top + w1_top + w2_top; + const float denom = fmaxf(w0_top + w1_top + w2_top, kEpsilon); const float w0 = w0_top / denom; const float w1 = w1_top / denom; const float w2 = w2_top / denom; @@ -208,7 +208,7 @@ BarycentricPerspectiveCorrectionBackward( const float w0_top = bary.x * z1 * z2; const float w1_top = z0 * bary.y * z2; const float w2_top = z0 * z1 * bary.z; - const float denom = w0_top + w1_top + w2_top; + const float denom = fmaxf(w0_top + w1_top + w2_top, kEpsilon); // Now do backward pass const float grad_denom_top = diff --git a/pytorch3d/csrc/utils/geometry_utils.h b/pytorch3d/csrc/utils/geometry_utils.h index c8b57f531..407849d8f 100644 --- a/pytorch3d/csrc/utils/geometry_utils.h +++ b/pytorch3d/csrc/utils/geometry_utils.h @@ -198,7 +198,7 @@ inline vec3 BarycentricPerspectiveCorrectionForward( const T w0_top = bary.x * z1 * z2; const T w1_top = bary.y * z0 * z2; const T w2_top = bary.z * z0 * z1; - const T denom = w0_top + w1_top + w2_top; + const T denom = std::max(w0_top + w1_top + w2_top, kEpsilon); const T w0 = w0_top / denom; const T w1 = w1_top / denom; const T w2 = w2_top / denom; @@ -229,7 +229,7 @@ inline std::tuple, T, T, T> BarycentricPerspectiveCorrectionBackward( const T w0_top = bary.x * z1 * z2; const T w1_top = bary.y * z0 * z2; const T w2_top = bary.z * z0 * z1; - const T denom = w0_top + w1_top + w2_top; + const T denom = std::max(w0_top + w1_top + w2_top, kEpsilon); // Now do backward pass const T grad_denom_top =