Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add None option for chamfer distance point reduction #1605

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 42 additions & 27 deletions pytorch3d/loss/chamfer.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,20 +13,22 @@


def _validate_chamfer_reduction_inputs(
batch_reduction: Union[str, None], point_reduction: str
batch_reduction: Union[str, None], point_reduction: Union[str, None]
) -> None:
"""Check the requested reductions are valid.

Args:
batch_reduction: Reduction operation to apply for the loss across the
batch, can be one of ["mean", "sum"] or None.
point_reduction: Reduction operation to apply for the loss across the
points, can be one of ["mean", "sum"].
points, can be one of ["mean", "sum"] or None.
"""
if batch_reduction is not None and batch_reduction not in ["mean", "sum"]:
raise ValueError('batch_reduction must be one of ["mean", "sum"] or None')
if point_reduction not in ["mean", "sum"]:
raise ValueError('point_reduction must be one of ["mean", "sum"]')
if point_reduction is not None and point_reduction not in ["mean", "sum"]:
raise ValueError('point_reduction must be one of ["mean", "sum"] or None')
if point_reduction is None and batch_reduction is not None:
raise ValueError("Batch reduction must be None if point_reduction is None")


def _handle_pointcloud_input(
Expand Down Expand Up @@ -77,7 +79,7 @@ def _chamfer_distance_single_direction(
y_normals,
weights,
batch_reduction: Union[str, None],
point_reduction: str,
point_reduction: Union[str, None],
norm: int,
abs_cosine: bool,
):
Expand Down Expand Up @@ -130,26 +132,28 @@ def _chamfer_distance_single_direction(

if weights is not None:
cham_norm_x *= weights.view(N, 1)
cham_norm_x = cham_norm_x.sum(1) # (N,)

# Apply point reduction
cham_x = cham_x.sum(1) # (N,)
if point_reduction == "mean":
x_lengths_clamped = x_lengths.clamp(min=1)
cham_x /= x_lengths_clamped
if point_reduction is not None:
# Apply point reduction
cham_x = cham_x.sum(1) # (N,)
if return_normals:
cham_norm_x /= x_lengths_clamped
cham_norm_x = cham_norm_x.sum(1) # (N,)
if point_reduction == "mean":
x_lengths_clamped = x_lengths.clamp(min=1)
cham_x /= x_lengths_clamped
if return_normals:
cham_norm_x /= x_lengths_clamped

if batch_reduction is not None:
# batch_reduction == "sum"
cham_x = cham_x.sum()
if return_normals:
cham_norm_x = cham_norm_x.sum()
if batch_reduction == "mean":
div = weights.sum() if weights is not None else max(N, 1)
cham_x /= div
if batch_reduction is not None:
# batch_reduction == "sum"
cham_x = cham_x.sum()
if return_normals:
cham_norm_x /= div
cham_norm_x = cham_norm_x.sum()
if batch_reduction == "mean":
div = weights.sum() if weights is not None else max(N, 1)
cham_x /= div
if return_normals:
cham_norm_x /= div

cham_dist = cham_x
cham_normals = cham_norm_x if return_normals else None
Expand All @@ -165,7 +169,7 @@ def chamfer_distance(
y_normals=None,
weights=None,
batch_reduction: Union[str, None] = "mean",
point_reduction: str = "mean",
point_reduction: Union[str, None] = "mean",
norm: int = 2,
single_directional: bool = False,
abs_cosine: bool = True,
Expand All @@ -191,7 +195,7 @@ def chamfer_distance(
batch_reduction: Reduction operation to apply for the loss across the
batch, can be one of ["mean", "sum"] or None.
point_reduction: Reduction operation to apply for the loss across the
points, can be one of ["mean", "sum"].
points, can be one of ["mean", "sum"] or None.
norm: int indicates the norm used for the distance. Supports 1 for L1 and 2 for L2.
single_directional: If False (default), loss comes from both the distance between
each point in x and its nearest neighbor in y and each point in y and its nearest
Expand All @@ -206,10 +210,16 @@ def chamfer_distance(
2-element tuple containing

- **loss**: Tensor giving the reduced distance between the pointclouds
in x and the pointclouds in y.
in x and the pointclouds in y. If batch_reduction is None, a 2-element
haritha-j marked this conversation as resolved.
Show resolved Hide resolved
tuple of Tensors containing forward and backward loss terms shaped (N, P1) and (N, P2) (if
single_directional is False) or a Tensor containing loss terms shaped (N, P1) (if
single_directional is True) is returned.
- **loss_normals**: Tensor giving the reduced cosine distance of normals
between pointclouds in x and pointclouds in y. Returns None if
x_normals and y_normals are None.
x_normals and y_normals are None. If batch_reduction is None, a 2-element
haritha-j marked this conversation as resolved.
Show resolved Hide resolved
tuple of Tensors containing forward and backward loss terms shaped (N, P1) and (N, P2) (if
single_directional is False) or a Tensor containing loss terms shaped (N, P1) (if
single_directional is True) is returned.

"""
_validate_chamfer_reduction_inputs(batch_reduction, point_reduction)
Expand Down Expand Up @@ -248,7 +258,12 @@ def chamfer_distance(
norm,
abs_cosine,
)
if point_reduction is not None:
return (
cham_x + cham_y,
(cham_norm_x + cham_norm_y) if cham_norm_x is not None else None,
)
return (
cham_x + cham_y,
(cham_norm_x + cham_norm_y) if cham_norm_x is not None else None,
(cham_x, cham_y),
(cham_norm_x, cham_norm_y) if cham_norm_x is not None else None,
)
214 changes: 177 additions & 37 deletions tests/test_chamfer.py
Original file line number Diff line number Diff line change
Expand Up @@ -421,9 +421,9 @@ def test_chamfer_pointcloud_object_withnormals(self):
("mean", "mean"),
("sum", None),
("mean", None),
(None, None),
]
for (point_reduction, batch_reduction) in reductions:

for point_reduction, batch_reduction in reductions:
# Reinitialize all the tensors so that the
# backward pass can be computed.
points_normals = TestChamfer.init_pointclouds(
Expand All @@ -450,24 +450,52 @@ def test_chamfer_pointcloud_object_withnormals(self):
batch_reduction=batch_reduction,
)

self.assertClose(cham_cloud, cham_tensor)
self.assertClose(norm_cloud, norm_tensor)
self._check_gradients(
cham_tensor,
norm_tensor,
cham_cloud,
norm_cloud,
points_normals.cloud1.points_list(),
points_normals.p1,
points_normals.cloud2.points_list(),
points_normals.p2,
points_normals.cloud1.normals_list(),
points_normals.n1,
points_normals.cloud2.normals_list(),
points_normals.n2,
points_normals.p1_lengths,
points_normals.p2_lengths,
)
if point_reduction is None:
cham_tensor_bidirectional = torch.hstack(
[cham_tensor[0], cham_tensor[1]]
)
norm_tensor_bidirectional = torch.hstack(
[norm_tensor[0], norm_tensor[1]]
)
cham_cloud_bidirectional = torch.hstack([cham_cloud[0], cham_cloud[1]])
norm_cloud_bidirectional = torch.hstack([norm_cloud[0], norm_cloud[1]])
self.assertClose(cham_cloud_bidirectional, cham_tensor_bidirectional)
self.assertClose(norm_cloud_bidirectional, norm_tensor_bidirectional)
self._check_gradients(
cham_tensor_bidirectional,
norm_tensor_bidirectional,
cham_cloud_bidirectional,
norm_cloud_bidirectional,
points_normals.cloud1.points_list(),
points_normals.p1,
points_normals.cloud2.points_list(),
points_normals.p2,
points_normals.cloud1.normals_list(),
points_normals.n1,
points_normals.cloud2.normals_list(),
points_normals.n2,
points_normals.p1_lengths,
points_normals.p2_lengths,
)
else:
self.assertClose(cham_cloud, cham_tensor)
self.assertClose(norm_cloud, norm_tensor)
self._check_gradients(
cham_tensor,
norm_tensor,
cham_cloud,
norm_cloud,
points_normals.cloud1.points_list(),
points_normals.p1,
points_normals.cloud2.points_list(),
points_normals.p2,
points_normals.cloud1.normals_list(),
points_normals.n1,
points_normals.cloud2.normals_list(),
points_normals.n2,
points_normals.p1_lengths,
points_normals.p2_lengths,
)

def test_chamfer_pointcloud_object_nonormals(self):
N = 5
Expand All @@ -481,9 +509,9 @@ def test_chamfer_pointcloud_object_nonormals(self):
("mean", "mean"),
("sum", None),
("mean", None),
(None, None),
]
for (point_reduction, batch_reduction) in reductions:

for point_reduction, batch_reduction in reductions:
# Reinitialize all the tensors so that the
# backward pass can be computed.
points_normals = TestChamfer.init_pointclouds(
Expand All @@ -508,19 +536,38 @@ def test_chamfer_pointcloud_object_nonormals(self):
batch_reduction=batch_reduction,
)

self.assertClose(cham_cloud, cham_tensor)
self._check_gradients(
cham_tensor,
None,
cham_cloud,
None,
points_normals.cloud1.points_list(),
points_normals.p1,
points_normals.cloud2.points_list(),
points_normals.p2,
lengths1=points_normals.p1_lengths,
lengths2=points_normals.p2_lengths,
)
if point_reduction is None:
cham_tensor_bidirectional = torch.hstack(
[cham_tensor[0], cham_tensor[1]]
)
cham_cloud_bidirectional = torch.hstack([cham_cloud[0], cham_cloud[1]])
self.assertClose(cham_cloud_bidirectional, cham_tensor_bidirectional)
self._check_gradients(
cham_tensor_bidirectional,
None,
cham_cloud_bidirectional,
None,
points_normals.cloud1.points_list(),
points_normals.p1,
points_normals.cloud2.points_list(),
points_normals.p2,
lengths1=points_normals.p1_lengths,
lengths2=points_normals.p2_lengths,
)
else:
self.assertClose(cham_cloud, cham_tensor)
self._check_gradients(
cham_tensor,
None,
cham_cloud,
None,
points_normals.cloud1.points_list(),
points_normals.p1,
points_normals.cloud2.points_list(),
points_normals.p2,
lengths1=points_normals.p1_lengths,
lengths2=points_normals.p2_lengths,
)

def test_chamfer_point_reduction_mean(self):
"""
Expand Down Expand Up @@ -707,6 +754,99 @@ def test_single_directional_chamfer_point_reduction_sum(self):
loss, loss_norm, pred_loss_sum, pred_loss_norm_sum, p1, p11, p2, p22
)

def test_chamfer_point_reduction_none(self):
"""
Compare output of vectorized chamfer loss with naive implementation
for point_reduction = None and batch_reduction = None.
"""
N, max_P1, max_P2 = 7, 10, 18
device = get_random_cuda_device()
points_normals = TestChamfer.init_pointclouds(N, max_P1, max_P2, device)
p1 = points_normals.p1
p2 = points_normals.p2
p1_normals = points_normals.n1
p2_normals = points_normals.n2
p11 = p1.detach().clone()
p22 = p2.detach().clone()
p11.requires_grad = True
p22.requires_grad = True

pred_loss, pred_loss_norm = TestChamfer.chamfer_distance_naive(
p1, p2, x_normals=p1_normals, y_normals=p2_normals
)

# point_reduction = None
loss, loss_norm = chamfer_distance(
p11,
p22,
x_normals=p1_normals,
y_normals=p2_normals,
batch_reduction=None,
point_reduction=None,
)

loss_bidirectional = torch.hstack([loss[0], loss[1]])
pred_loss_bidirectional = torch.hstack([pred_loss[0], pred_loss[1]])
loss_norm_bidirectional = torch.hstack([loss_norm[0], loss_norm[1]])
pred_loss_norm_bidirectional = torch.hstack(
[pred_loss_norm[0], pred_loss_norm[1]]
)

self.assertClose(loss_bidirectional, pred_loss_bidirectional)
self.assertClose(loss_norm_bidirectional, pred_loss_norm_bidirectional)

# Check gradients
self._check_gradients(
loss_bidirectional,
loss_norm_bidirectional,
pred_loss_bidirectional,
pred_loss_norm_bidirectional,
p1,
p11,
p2,
p22,
)

def test_single_direction_chamfer_point_reduction_none(self):
"""
Compare output of vectorized chamfer loss with naive implementation
for point_reduction = None and batch_reduction = None.
"""
N, max_P1, max_P2 = 7, 10, 18
device = get_random_cuda_device()
points_normals = TestChamfer.init_pointclouds(N, max_P1, max_P2, device)
p1 = points_normals.p1
p2 = points_normals.p2
p1_normals = points_normals.n1
p2_normals = points_normals.n2
p11 = p1.detach().clone()
p22 = p2.detach().clone()
p11.requires_grad = True
p22.requires_grad = True

pred_loss, pred_loss_norm = TestChamfer.chamfer_distance_naive(
p1, p2, x_normals=p1_normals, y_normals=p2_normals
)

# point_reduction = None
loss, loss_norm = chamfer_distance(
p11,
p22,
x_normals=p1_normals,
y_normals=p2_normals,
batch_reduction=None,
point_reduction=None,
single_directional=True,
)

self.assertClose(loss, pred_loss[0])
self.assertClose(loss_norm, pred_loss_norm[0])

# Check gradients
self._check_gradients(
loss, loss_norm, pred_loss[0], pred_loss_norm[0], p1, p11, p2, p22
)

def _check_gradients(
self,
loss,
Expand Down Expand Up @@ -880,9 +1020,9 @@ def test_chamfer_joint_reduction(self):
with self.assertRaisesRegex(ValueError, "batch_reduction must be one of"):
chamfer_distance(p1, p2, weights=weights, batch_reduction="max")

# Error when point_reduction is not in ["mean", "sum"].
# Error when point_reduction is not in ["mean", "sum"] or None.
with self.assertRaisesRegex(ValueError, "point_reduction must be one of"):
chamfer_distance(p1, p2, weights=weights, point_reduction=None)
chamfer_distance(p1, p2, weights=weights, point_reduction="max")

def test_incorrect_weights(self):
N, P1, P2 = 16, 64, 128
Expand Down