diff --git a/pytorch3d/implicitron/models/renderer/multipass_ea.py b/pytorch3d/implicitron/models/renderer/multipass_ea.py index 03511c295..92042e131 100644 --- a/pytorch3d/implicitron/models/renderer/multipass_ea.py +++ b/pytorch3d/implicitron/models/renderer/multipass_ea.py @@ -157,9 +157,13 @@ def _run_raymarcher( else 0.0 ) + ray_deltas = ( + None if ray_bundle.bins is None else torch.diff(ray_bundle.bins, dim=-1) + ) output = self.raymarcher( *implicit_functions[0](ray_bundle=ray_bundle), ray_lengths=ray_bundle.lengths, + ray_deltas=ray_deltas, density_noise_std=density_noise_std, ) output.prev_stage = prev_stage diff --git a/pytorch3d/implicitron/models/renderer/ray_point_refiner.py b/pytorch3d/implicitron/models/renderer/ray_point_refiner.py index ffdfe2a9e..0266f939d 100644 --- a/pytorch3d/implicitron/models/renderer/ray_point_refiner.py +++ b/pytorch3d/implicitron/models/renderer/ray_point_refiner.py @@ -78,19 +78,28 @@ def forward( """ - z_vals = input_ray_bundle.lengths with torch.no_grad(): if self.blurpool_weights: ray_weights = apply_blurpool_on_weights(ray_weights) - z_vals_mid = torch.lerp(z_vals[..., 1:], z_vals[..., :-1], 0.5) + n_pts_per_ray = self.n_pts_per_ray + ray_weights = ray_weights.view(-1, ray_weights.shape[-1]) + if input_ray_bundle.bins is None: + z_vals: torch.Tensor = input_ray_bundle.lengths + ray_weights = ray_weights[..., 1:-1] + bins = torch.lerp(z_vals[..., 1:], z_vals[..., :-1], 0.5) + else: + z_vals = input_ray_bundle.bins + n_pts_per_ray += 1 + bins = z_vals z_samples = sample_pdf( - z_vals_mid.view(-1, z_vals_mid.shape[-1]), - ray_weights.view(-1, ray_weights.shape[-1])[..., 1:-1], - self.n_pts_per_ray, + bins.view(-1, bins.shape[-1]), + ray_weights, + n_pts_per_ray, det=not self.random_sampling, eps=self.sample_pdf_eps, - ).view(*z_vals.shape[:-1], self.n_pts_per_ray) + ).view(*z_vals.shape[:-1], n_pts_per_ray) + if self.add_input_samples: z_vals = torch.cat((z_vals, z_samples), dim=-1) else: @@ -98,9 +107,13 @@ def forward( # Resort by depth. z_vals, _ = torch.sort(z_vals, dim=-1) - new_bundle = ImplicitronRayBundle(**vars(input_ray_bundle)) - new_bundle.lengths = z_vals - return new_bundle + kwargs_ray = dict(vars(input_ray_bundle)) + if input_ray_bundle.bins is None: + kwargs_ray["lengths"] = z_vals + return ImplicitronRayBundle(**kwargs_ray) + kwargs_ray["bins"] = z_vals + del kwargs_ray["lengths"] + return ImplicitronRayBundle.from_bins(**kwargs_ray) def apply_blurpool_on_weights(weights) -> torch.Tensor: diff --git a/pytorch3d/implicitron/models/renderer/raymarcher.py b/pytorch3d/implicitron/models/renderer/raymarcher.py index 1e4ec1e64..9c6addf1a 100644 --- a/pytorch3d/implicitron/models/renderer/raymarcher.py +++ b/pytorch3d/implicitron/models/renderer/raymarcher.py @@ -4,7 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from typing import Any, Callable, Dict, Tuple +from typing import Any, Callable, Dict, Optional, Tuple import torch from pytorch3d.implicitron.models.renderer.base import RendererOutput @@ -119,6 +119,7 @@ def forward( rays_features: torch.Tensor, aux: Dict[str, Any], ray_lengths: torch.Tensor, + ray_deltas: Optional[torch.Tensor] = None, density_noise_std: float = 0.0, **kwargs, ) -> RendererOutput: @@ -131,6 +132,9 @@ def forward( aux: a dictionary with extra information. ray_lengths: Per-ray depth values represented with a tensor of shape `(..., n_points_per_ray, feature_dim)`. + ray_deltas: Optional differences between consecutive elements along the ray bundle + represented with a tensor of shape `(..., n_points_per_ray)`. If None, + these differences are computed from ray_lengths. density_noise_std: the magnitude of the noise added to densities. Returns: @@ -152,14 +156,17 @@ def forward( density_1d=True, ) - ray_lengths_diffs = ray_lengths[..., 1:] - ray_lengths[..., :-1] - if self.replicate_last_interval: - last_interval = ray_lengths_diffs[..., -1:] + if ray_deltas is None: + ray_lengths_diffs = torch.diff(ray_lengths, dim=-1) + if self.replicate_last_interval: + last_interval = ray_lengths_diffs[..., -1:] + else: + last_interval = torch.full_like( + ray_lengths[..., :1], self.background_opacity + ) + deltas = torch.cat((ray_lengths_diffs, last_interval), dim=-1) else: - last_interval = torch.full_like( - ray_lengths[..., :1], self.background_opacity - ) - deltas = torch.cat((ray_lengths_diffs, last_interval), dim=-1) + deltas = ray_deltas rays_densities = rays_densities[..., 0] diff --git a/pytorch3d/renderer/implicit/harmonic_embedding.py b/pytorch3d/renderer/implicit/harmonic_embedding.py index 418eaa73e..90e857f8a 100644 --- a/pytorch3d/renderer/implicit/harmonic_embedding.py +++ b/pytorch3d/renderer/implicit/harmonic_embedding.py @@ -24,7 +24,7 @@ def __init__( and the integrated position encoding in `MIP-NeRF `_. - During, the inference you can provide the extra argument `diag_cov`. + During the inference you can provide the extra argument `diag_cov`. If `diag_cov is None`, it converts rays parametrized with a `ray_bundle` to 3D points by diff --git a/tests/implicitron/test_ray_point_refiner.py b/tests/implicitron/test_ray_point_refiner.py index c4e7b2208..8a6ab3612 100644 --- a/tests/implicitron/test_ray_point_refiner.py +++ b/tests/implicitron/test_ray_point_refiner.py @@ -70,6 +70,71 @@ def test_simple(self): (lengths_random[..., 1:] - lengths_random[..., :-1] > 0).all() ) + def test_simple_use_bins(self): + """ + Same spirit than test_simple but use bins in the ImplicitronRayBunle. + It has been duplicated to avoid cognitive overload while reading the + test (lot of if else). + """ + length = 15 + n_pts_per_ray = 10 + + for add_input_samples, use_blurpool in product([False, True], [False, True]): + ray_point_refiner = RayPointRefiner( + n_pts_per_ray=n_pts_per_ray, + random_sampling=False, + add_input_samples=add_input_samples, + ) + + bundle = ImplicitronRayBundle( + lengths=None, + bins=torch.arange(length + 1, dtype=torch.float32).expand( + 3, 25, length + 1 + ), + origins=None, + directions=None, + xys=None, + camera_ids=None, + camera_counts=None, + ) + weights = torch.ones(3, 25, length) + refined = ray_point_refiner(bundle, weights, blurpool_weights=use_blurpool) + + self.assertIsNone(refined.directions) + self.assertIsNone(refined.origins) + self.assertIsNone(refined.xys) + expected_bins = torch.linspace(0, length, n_pts_per_ray + 1) + expected_bins = expected_bins.expand(3, 25, n_pts_per_ray + 1) + if add_input_samples: + expected_bins = torch.cat((bundle.bins, expected_bins), dim=-1).sort()[ + 0 + ] + full_expected = torch.lerp( + expected_bins[..., :-1], expected_bins[..., 1:], 0.5 + ) + + self.assertClose(refined.lengths, full_expected) + + ray_point_refiner_random = RayPointRefiner( + n_pts_per_ray=n_pts_per_ray, + random_sampling=True, + add_input_samples=add_input_samples, + ) + + refined_random = ray_point_refiner_random( + bundle, weights, blurpool_weights=use_blurpool + ) + lengths_random = refined_random.lengths + self.assertEqual(lengths_random.shape, full_expected.shape) + if not add_input_samples: + self.assertGreater(lengths_random.min().item(), 0) + self.assertLess(lengths_random.max().item(), length) + + # Check sorted + self.assertTrue( + (lengths_random[..., 1:] - lengths_random[..., :-1] > 0).all() + ) + def test_apply_blurpool_on_weights(self): weights = torch.tensor( [