From f9a26a22fcd6f7f16eae7dc8fd6e48ecadd7486b Mon Sep 17 00:00:00 2001 From: Nikhila Ravi Date: Tue, 21 Dec 2021 15:03:33 -0800 Subject: [PATCH] Move Harmonic embedding to core pytorch3d Summary: Moved `HarmonicEmbedding` function in core PyTorch3D. In the next diff will update the NeRF project. Reviewed By: bottler Differential Revision: D32833808 fbshipit-source-id: 0a12ccd1627c0ce024463c796544c91eb8d4d122 --- pytorch3d/renderer/__init__.py | 1 + pytorch3d/renderer/implicit/__init__.py | 2 +- .../renderer/implicit/harmonic_embedding.py | 127 ++++++++++++++++++ tests/test_harmonic_embedding.py | 50 +++++++ 4 files changed, 179 insertions(+), 1 deletion(-) create mode 100644 pytorch3d/renderer/implicit/harmonic_embedding.py create mode 100644 tests/test_harmonic_embedding.py diff --git a/pytorch3d/renderer/__init__.py b/pytorch3d/renderer/__init__.py index e7ea97763..fe9ccdf10 100644 --- a/pytorch3d/renderer/__init__.py +++ b/pytorch3d/renderer/__init__.py @@ -37,6 +37,7 @@ VolumeSampler, ray_bundle_to_ray_points, ray_bundle_variables_to_ray_points, + HarmonicEmbedding, ) from .lighting import AmbientLights, DirectionalLights, PointLights, diffuse, specular from .materials import Materials diff --git a/pytorch3d/renderer/implicit/__init__.py b/pytorch3d/renderer/implicit/__init__.py index 248898cea..66a59ae50 100644 --- a/pytorch3d/renderer/implicit/__init__.py +++ b/pytorch3d/renderer/implicit/__init__.py @@ -4,6 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
+from .harmonic_embedding import HarmonicEmbedding
 from .raymarching import AbsorptionOnlyRaymarcher, EmissionAbsorptionRaymarcher
 from .raysampling import GridRaysampler, MonteCarloRaysampler, NDCGridRaysampler
 from .renderer import ImplicitRenderer, VolumeRenderer, VolumeSampler
@@ -13,5 +14,4 @@
     ray_bundle_variables_to_ray_points,
 )
 
-
 __all__ = [k for k in globals().keys() if not k.startswith("_")]
diff --git a/pytorch3d/renderer/implicit/harmonic_embedding.py b/pytorch3d/renderer/implicit/harmonic_embedding.py
new file mode 100644
index 000000000..d07e7eb9a
--- /dev/null
+++ b/pytorch3d/renderer/implicit/harmonic_embedding.py
@@ -0,0 +1,127 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+
+
+class HarmonicEmbedding(torch.nn.Module):
+    def __init__(
+        self,
+        n_harmonic_functions: int = 6,
+        omega_0: float = 1.0,
+        logspace: bool = True,
+        append_input: bool = True,
+    ) -> None:
+        """
+        Given an input tensor `x` of shape [minibatch, ... , dim],
+        the harmonic embedding layer converts each feature
+        (i.e. vector along the last dimension) in `x`
+        into a series of harmonic features `embedding`,
+        where for each i in range(dim) the following are present
+        in embedding[...]:
+            ```
+            [
+                sin(f_1*x[..., i]),
+                sin(f_2*x[..., i]),
+                ...
+                sin(f_N * x[..., i]),
+                cos(f_1*x[..., i]),
+                cos(f_2*x[..., i]),
+                ...
+                cos(f_N * x[..., i]),
+                x[..., i],  # only present if append_input is True.
+            ]
+            ```
+        where N corresponds to `n_harmonic_functions`, and f_i is a scalar
+        denoting the i-th frequency of the harmonic embedding.
+
+        If `logspace==True`, the frequencies `[f_1, ..., f_N]` are
+        powers of 2:
+            `f_1, ..., f_N = 2**torch.arange(n_harmonic_functions)`
+
+        If `logspace==False`, frequencies are linearly spaced between
+        `1.0` and `2**(n_harmonic_functions-1)`:
+            `f_1, ..., f_N = torch.linspace(
+                1.0, 2**(n_harmonic_functions-1), n_harmonic_functions
+            )`
+
+        Note that `x` is also premultiplied by the base frequency `omega_0`
+        before evaluating the harmonic functions.
+
+        Args:
+            n_harmonic_functions: int, number of harmonic features
+            omega_0: float, base frequency
+            logspace: bool, whether to space the frequencies in logspace or linear space
+            append_input: bool, whether to concat the original
+                input to the harmonic embedding. If true the
+                output is of the form (embed.sin(), embed.cos(), x)
+        """
+        super().__init__()
+
+        if logspace:
+            frequencies = 2.0 ** torch.arange(
+                n_harmonic_functions,
+                dtype=torch.float32,
+            )
+        else:
+            frequencies = torch.linspace(
+                1.0,
+                2.0 ** (n_harmonic_functions - 1),
+                n_harmonic_functions,
+                dtype=torch.float32,
+            )
+
+        self.register_buffer("_frequencies", frequencies * omega_0, persistent=False)
+        self.append_input = append_input
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """
+        Args:
+            x: tensor of shape [..., dim]
+        Returns:
+            embedding: a harmonic embedding of `x`
+                of shape [..., (n_harmonic_functions * 2 + int(append_input)) * dim],
+                laid out along the last dimension as (sin, cos[, x]).
+        """
+        # `reshape` instead of `view`: the broadcast product is not guaranteed to
+        # be contiguous when `x` is a non-contiguous (e.g. transposed) tensor.
+        embed = (x[..., None] * self._frequencies).reshape(*x.shape[:-1], -1)
+        embed = torch.cat(
+            (embed.sin(), embed.cos(), x)
+            if self.append_input
+            else (embed.sin(), embed.cos()),
+            dim=-1,
+        )
+        return embed
+
+    @staticmethod
+    def get_output_dim_static(
+        input_dims: int,
+        n_harmonic_functions: int,
+        append_input: bool,
+    ) -> int:
+        """
+        Utility to help predict the shape of the output of `forward`.
+
+        Args:
+            input_dims: length of the last dimension of the input tensor
+            n_harmonic_functions: number of embedding frequencies
+            append_input: whether or not to concat the original
+                input to the harmonic embedding
+        Returns:
+            int: the length of the last dimension of the output tensor
+        """
+        return input_dims * (2 * n_harmonic_functions + int(append_input))
+
+    def get_output_dim(self, input_dims: int = 3) -> int:
+        """
+        Same as `get_output_dim_static`, using this module's own settings.
+        The default for input_dims is 3 for 3D applications which use
+        harmonic embedding for positional encoding, so the input might be xyz.
+        """
+        return self.get_output_dim_static(
+            input_dims, len(self._frequencies), self.append_input
+        )
diff --git a/tests/test_harmonic_embedding.py b/tests/test_harmonic_embedding.py
new file mode 100644
index 000000000..2656936d3
--- /dev/null
+++ b/tests/test_harmonic_embedding.py
@@ -0,0 +1,50 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import unittest
+
+import torch
+from common_testing import TestCaseMixin
+from pytorch3d.renderer.implicit import HarmonicEmbedding
+
+
+class TestHarmonicEmbedding(TestCaseMixin, unittest.TestCase):
+    def setUp(self) -> None:
+        super().setUp()
+        torch.manual_seed(1)
+
+    def test_correct_output_dim(self):
+        # Expected: input_dims * (2 * n_harmonic_functions + int(append_input))
+        embed_fun = HarmonicEmbedding(n_harmonic_functions=2, append_input=False)
+        output_dim = 3 * (2 * 2 + int(False))
+        static_dim = embed_fun.get_output_dim_static(
+            input_dims=3, n_harmonic_functions=2, append_input=False
+        )
+        self.assertEqual(output_dim, static_dim)
+        self.assertEqual(output_dim, embed_fun.get_output_dim())
+
+    def test_correct_frequency_range(self):
+        # 3 frequencies: powers of two (log) vs. evenly spaced in [1, 2**2] (linear).
+        embed_fun_log = HarmonicEmbedding(n_harmonic_functions=3)
+        embed_fun_lin = HarmonicEmbedding(n_harmonic_functions=3, logspace=False)
+        self.assertClose(embed_fun_log._frequencies, torch.tensor((1.0, 2.0, 4.0)))
+        self.assertClose(embed_fun_lin._frequencies, torch.tensor((1.0, 2.5, 4.0)))
+
+    def test_correct_embed_out(self):
+        embed_fun = HarmonicEmbedding(n_harmonic_functions=2, append_input=False)
+        x = torch.randn((1, 5))
+        D = 5 * 4
+        embed_out = embed_fun(x)
+        self.assertEqual(embed_out.shape, (1, D))
+        # sin**2 + cos**2 == 1: the first half holds sines, the second half cosines.
+        sum_squares = embed_out[0, : D // 2] ** 2 + embed_out[0, D // 2 :] ** 2
+        self.assertClose(sum_squares, torch.ones((D // 2)))
+        embed_fun = HarmonicEmbedding(n_harmonic_functions=2, append_input=True)
+        embed_out = embed_fun(x)
+        # 5 input dims * (2 * 2 harmonics + 1 for the appended input)
+        self.assertEqual(embed_out.shape, (1, 5 * 5))
+        # The appended input occupies the last 5 channels.
+        self.assertClose(embed_out[..., -5:], x)