Skip to content

Commit

Permalink
small fix for iou3d
Browse files Browse the repository at this point in the history
Summary:
A small numerical fix for IoU for 3D boxes, fixes GH #992

* Adds a check for boxes with zero side areas (invalid boxes)
* Fixes numerical issue when two boxes have coplanar sides

Reviewed By: nikhilaravi

Differential Revision: D33195691

fbshipit-source-id: 8a34b4d1f1e5ec2edb6d54143930da44bdde0906
  • Loading branch information
gkioxari authored and facebook-github-bot committed Dec 18, 2021
1 parent 069c9fd commit ccfb72c
Show file tree
Hide file tree
Showing 6 changed files with 202 additions and 4 deletions.
3 changes: 2 additions & 1 deletion pytorch3d/csrc/iou_box3d/iou_box3d.cu
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,8 @@ __global__ void IoUBox3DKernel(
for (int b2 = 0; b2 < box2_count; ++b2) {
const bool is_coplanar =
IsCoplanarFace(box1_intersect[b1], box2_intersect[b2]);
if (is_coplanar) {
const float area = FaceArea(box1_intersect[b1]);
if ((is_coplanar) && (area > kEpsilon)) {
tri2_keep[b2].keep = false;
}
}
Expand Down
3 changes: 2 additions & 1 deletion pytorch3d/csrc/iou_box3d/iou_box3d_cpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,8 @@ std::tuple<at::Tensor, at::Tensor> IoUBox3DCpu(
for (int b2 = 0; b2 < box2_intersect.size(); ++b2) {
const bool is_coplanar =
IsCoplanarFace(box1_intersect[b1], box2_intersect[b2]);
if (is_coplanar) {
const float area = FaceArea(box1_intersect[b1]);
if ((is_coplanar) && (area > kEpsilon)) {
tri2_keep[b2] = 0;
}
}
Expand Down
20 changes: 20 additions & 0 deletions pytorch3d/csrc/iou_box3d/iou_utils.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,26 @@ FaceNormal(const float3 v0, const float3 v1, const float3 v2) {
return n;
}

// The area of the face defined by vertices (v0, v1, v2)
// Define e0 to be the edge connecting (v1, v0)
// Define e1 to be the edge connecting (v2, v0)
// Area is the norm of the cross product of e0, e1 divided by 2.0
//
// Args
// tri: FaceVerts of float3 coordinates of the vertices of the face
//
// Returns
// float: area for the face
//
__device__ inline float FaceArea(const FaceVerts& tri) {
// Get verts for face 1
const float3 v0 = tri.v0;
const float3 v1 = tri.v1;
const float3 v2 = tri.v2;
const float3 n = cross(v1 - v0, v2 - v0);
return norm(n) / 2.0;
}

// The normal of a box plane defined by the verts in `plane` with
// the centroid of the box given by `center`.
// Args
Expand Down
20 changes: 20 additions & 0 deletions pytorch3d/csrc/iou_box3d/iou_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,26 @@ inline vec3<float> FaceNormal(vec3<float> v0, vec3<float> v1, vec3<float> v2) {
return n;
}

// The area of the face defined by vertices (v0, v1, v2)
// Define e0 to be the edge connecting (v1, v0)
// Define e1 to be the edge connecting (v2, v0)
// Area is the norm of the cross product of e0, e1 divided by 2.0
//
// Args
// tri: vec3 coordinates of the vertices of the face
//
// Returns
// float: area for the face
//
inline float FaceArea(const std::vector<vec3<float>>& tri) {
// Get verts for face
const vec3<float> v0 = tri[0];
const vec3<float> v1 = tri[1];
const vec3<float> v2 = tri[2];
const vec3<float> n = cross(v1 - v0, v2 - v0);
return norm(n) / 2.0;
}

// The normal of a box plane defined by the verts in `plane` with
// the centroid of the box given by `center`.
// Args
Expand Down
24 changes: 24 additions & 0 deletions pytorch3d/ops/iou_box3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,28 @@ def _check_coplanar(boxes: torch.Tensor, eps: float = 1e-4) -> None:
return


def _check_nonzero(boxes: torch.Tensor, eps: float = 1e-4) -> None:
"""
Checks that the sides of the box have a non zero area
"""
faces = torch.tensor(_box_triangles, dtype=torch.int64, device=boxes.device)
# pyre-fixme[16]: `boxes` has no attribute `index_select`.
verts = boxes.index_select(index=faces.view(-1), dim=1)
B = boxes.shape[0]
T, V = faces.shape
# (B, T, 3, 3) -> (B, T, 3)
v0, v1, v2 = verts.reshape(B, T, V, 3).unbind(2)

normals = torch.cross(v1 - v0, v2 - v0, dim=-1) # (B, T, 3)
face_areas = normals.norm(dim=-1) / 2

if (face_areas < eps).any().item():
msg = "Planes have zero areas"
raise ValueError(msg)

return


class _box3d_overlap(Function):
"""
Torch autograd Function wrapper for box3d_overlap C++/CUDA implementations.
Expand Down Expand Up @@ -138,6 +160,8 @@ def box3d_overlap(

_check_coplanar(boxes1, eps)
_check_coplanar(boxes2, eps)
_check_nonzero(boxes1, eps)
_check_nonzero(boxes2, eps)

# pyre-fixme[16]: `_box3d_overlap` has no attribute `apply`.
vol, iou = _box3d_overlap.apply(boxes1, boxes2)
Expand Down
136 changes: 134 additions & 2 deletions tests/test_iou_box3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,11 @@ def _test_iou(self, overlap_fn, device):
self.assertClose(
vol, torch.tensor([[1 - dd]], device=vol.device, dtype=vol.dtype)
)
# symmetry
vol, iou = overlap_fn(box2[None], box1[None])
self.assertClose(
vol, torch.tensor([[1 - dd]], device=vol.device, dtype=vol.dtype)
)

# 3rd test
dd = random.random()
Expand All @@ -119,6 +124,11 @@ def _test_iou(self, overlap_fn, device):
self.assertClose(
vol, torch.tensor([[1 - dd]], device=vol.device, dtype=vol.dtype)
)
# symmetry
vol, _ = overlap_fn(box2[None], box1[None])
self.assertClose(
vol, torch.tensor([[1 - dd]], device=vol.device, dtype=vol.dtype)
)

# 4th test
ddx, ddy, ddz = random.random(), random.random(), random.random()
Expand All @@ -132,6 +142,16 @@ def _test_iou(self, overlap_fn, device):
dtype=vol.dtype,
),
)
# symmetry
vol, _ = overlap_fn(box2[None], box1[None])
self.assertClose(
vol,
torch.tensor(
[[(1 - ddx) * (1 - ddy) * (1 - ddz)]],
device=vol.device,
dtype=vol.dtype,
),
)

# Also check IoU is 1 when computing overlap with the same shifted box
vol, iou = overlap_fn(box2[None], box2[None])
Expand All @@ -152,6 +172,16 @@ def _test_iou(self, overlap_fn, device):
dtype=vol.dtype,
),
)
# symmetry
vol, _ = overlap_fn(box2r[None], box1r[None])
self.assertClose(
vol,
torch.tensor(
[[(1 - ddx) * (1 - ddy) * (1 - ddz)]],
device=vol.device,
dtype=vol.dtype,
),
)

# 6th test
ddx, ddy, ddz = random.random(), random.random(), random.random()
Expand All @@ -170,6 +200,17 @@ def _test_iou(self, overlap_fn, device):
),
atol=1e-7,
)
# symmetry
vol, _ = overlap_fn(box2r[None], box1r[None])
self.assertClose(
vol,
torch.tensor(
[[(1 - ddx) * (1 - ddy) * (1 - ddz)]],
device=vol.device,
dtype=vol.dtype,
),
atol=1e-7,
)

# 7th test: hand coded example and test with meshlab output

Expand Down Expand Up @@ -214,6 +255,10 @@ def _test_iou(self, overlap_fn, device):
vol, iou = overlap_fn(box1r[None], box2r[None])
self.assertClose(vol, torch.tensor([[vol_inters]], device=device), atol=1e-1)
self.assertClose(iou, torch.tensor([[iou_mesh]], device=device), atol=1e-1)
# symmetry
vol, iou = overlap_fn(box2r[None], box1r[None])
self.assertClose(vol, torch.tensor([[vol_inters]], device=device), atol=1e-1)
self.assertClose(iou, torch.tensor([[iou_mesh]], device=device), atol=1e-1)

# 8th test: compare with sampling
# create box1
Expand All @@ -232,14 +277,20 @@ def _test_iou(self, overlap_fn, device):
iou_sampling = self._box3d_overlap_sampling_batched(
box1r[None], box2r[None], num_samples=10000
)

self.assertClose(iou, iou_sampling, atol=1e-2)
# symmetry
vol, iou = overlap_fn(box2r[None], box1r[None])
self.assertClose(iou, iou_sampling, atol=1e-2)

# 9th test: non overlapping boxes, iou = 0.0
box2 = box1 + torch.tensor([[0.0, 100.0, 0.0]], device=device)
vol, iou = overlap_fn(box1[None], box2[None])
self.assertClose(vol, torch.tensor([[0.0]], device=vol.device, dtype=vol.dtype))
self.assertClose(iou, torch.tensor([[0.0]], device=vol.device, dtype=vol.dtype))
# symmetry
vol, iou = overlap_fn(box2[None], box1[None])
self.assertClose(vol, torch.tensor([[0.0]], device=vol.device, dtype=vol.dtype))
self.assertClose(iou, torch.tensor([[0.0]], device=vol.device, dtype=vol.dtype))

# 10th test: Non coplanar verts in a plane
box10 = box1 + torch.rand((8, 3), dtype=torch.float32, device=device)
Expand Down Expand Up @@ -284,6 +335,56 @@ def _test_iou(self, overlap_fn, device):
vols, ious = overlap_fn(box_skew_1[None], box_skew_2[None])
self.assertClose(vols, torch.tensor([[vol_inters]], device=device), atol=1e-1)
self.assertClose(ious, torch.tensor([[iou]], device=device), atol=1e-1)
# symmetry
vols, ious = overlap_fn(box_skew_2[None], box_skew_1[None])
self.assertClose(vols, torch.tensor([[vol_inters]], device=device), atol=1e-1)
self.assertClose(ious, torch.tensor([[iou]], device=device), atol=1e-1)

# 12th test: Zero area bounding box (from GH issue #992)
box12a = torch.tensor(
[
[-1.0000, -1.0000, -0.5000],
[1.0000, -1.0000, -0.5000],
[1.0000, 1.0000, -0.5000],
[-1.0000, 1.0000, -0.5000],
[-1.0000, -1.0000, 0.5000],
[1.0000, -1.0000, 0.5000],
[1.0000, 1.0000, 0.5000],
[-1.0000, 1.0000, 0.5000],
],
device=device,
dtype=torch.float32,
)

box12b = torch.tensor(
[
[0.0, 0.0, 0.0],
[0.0, 0.0, 0.0],
[0.0, 0.0, 0.0],
[0.0, 0.0, 0.0],
[0.0, 0.0, 0.0],
[0.0, 0.0, 0.0],
[0.0, 0.0, 0.0],
[0.0, 0.0, 0.0],
],
device=device,
dtype=torch.float32,
)
msg = "Planes have zero areas"
with self.assertRaisesRegex(ValueError, msg):
overlap_fn(box12a[None], box12b[None])
# symmetry
with self.assertRaisesRegex(ValueError, msg):
overlap_fn(box12b[None], box12a[None])

# 13th test: From GH issue #992
# Zero area coplanar face after intersection
ctrs = torch.tensor([[0.0, 0.0, 0.0], [-1.0, 1.0, 0.0]])
whl = torch.tensor([[2.0, 2.0, 2.0], [2.0, 2, 2]])
box13a = TestIoU3D.create_box(ctrs[0], whl[0])
box13b = TestIoU3D.create_box(ctrs[1], whl[1])
vol, iou = overlap_fn(box13a[None], box13b[None])
self.assertClose(vol, torch.tensor([[2.0]], device=vol.device, dtype=vol.dtype))

def _test_real_boxes(self, overlap_fn, device):
data_filename = "./real_boxes.pkl"
Expand Down Expand Up @@ -577,6 +678,13 @@ def box_planar_dir(box: torch.Tensor, eps=1e-4) -> torch.Tensor:
msg = "Plane vertices are not coplanar"
raise ValueError(msg)

# Check all faces have non zero area
area1 = torch.cross(v1 - v0, v2 - v0, dim=-1).norm(dim=-1) / 2
area2 = torch.cross(v3 - v0, v2 - v0, dim=-1).norm(dim=-1) / 2
if (area1 < eps).any().item() or (area2 < eps).any().item():
msg = "Planes have zero areas"
raise ValueError(msg)

# We can write: `ctr = v0 + a * e0 + b * e1 + c * n`, (1).
# With <e0, n> = 0 and <e1, n> = 0, where <.,.> refers to the dot product,
# since that e0 is orthogonal to n. Same for e1.
Expand Down Expand Up @@ -607,6 +715,27 @@ def box_planar_dir(box: torch.Tensor, eps=1e-4) -> torch.Tensor:
return n


def tri_verts_area(tri_verts: torch.Tensor) -> torch.Tensor:
"""
Computes the area of the triangle faces in tri_verts
Args:
tri_verts: tensor of shape (T, 3, 3)
Returns:
areas: the area of the triangles (T, 1)
"""
add_dim = False
if tri_verts.ndim == 2:
tri_verts = tri_verts.unsqueeze(0)
add_dim = True

v0, v1, v2 = tri_verts.unbind(1)
areas = torch.cross(v1 - v0, v2 - v0, dim=-1).norm(dim=-1) / 2.0

if add_dim:
areas = areas[0]
return areas


def box_volume(box: torch.Tensor) -> torch.Tensor:
"""
Computes the volume of each box in boxes.
Expand Down Expand Up @@ -988,7 +1117,10 @@ def box3d_overlap_naive(box1: torch.Tensor, box2: torch.Tensor):
keep2 = torch.ones((tri_verts2.shape[0],), device=device, dtype=torch.bool)
for i1 in range(tri_verts1.shape[0]):
for i2 in range(tri_verts2.shape[0]):
if coplanar_tri_faces(tri_verts1[i1], tri_verts2[i2]):
if (
coplanar_tri_faces(tri_verts1[i1], tri_verts2[i2])
and tri_verts_area(tri_verts1[i1]) > 1e-4
):
keep2[i2] = 0
keep2 = keep2.nonzero()[:, 0]
tri_verts2 = tri_verts2[keep2]
Expand Down

0 comments on commit ccfb72c

Please sign in to comment.