-
Notifications
You must be signed in to change notification settings - Fork 1.3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Move Harmonic embedding to core pytorch3d
Summary: Moved `HarmonicEmbedding` function in core PyTorch3D. In the next diff will update the NeRF project. Reviewed By: bottler Differential Revision: D32833808 fbshipit-source-id: 0a12ccd1627c0ce024463c796544c91eb8d4d122
- Loading branch information
1 parent
d67662d
commit f9a26a2
Showing
4 changed files
with
179 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,127 @@ | ||
# Copyright (c) Facebook, Inc. and its affiliates. | ||
# All rights reserved. | ||
# | ||
# This source code is licensed under the BSD-style license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
import torch | ||
|
||
|
||
class HarmonicEmbedding(torch.nn.Module): | ||
def __init__( | ||
self, | ||
n_harmonic_functions: int = 6, | ||
omega_0: float = 1.0, | ||
logspace: bool = True, | ||
append_input: bool = True, | ||
) -> None: | ||
""" | ||
Given an input tensor `x` of shape [minibatch, ... , dim], | ||
the harmonic embedding layer converts each feature | ||
(i.e. vector along the last dimension) in `x` | ||
into a series of harmonic features `embedding`, | ||
where for each i in range(dim) the following are present | ||
in embedding[...]: | ||
``` | ||
[ | ||
sin(f_1*x[..., i]), | ||
sin(f_2*x[..., i]), | ||
... | ||
sin(f_N * x[..., i]), | ||
cos(f_1*x[..., i]), | ||
cos(f_2*x[..., i]), | ||
... | ||
cos(f_N * x[..., i]), | ||
x[..., i], # only present if append_input is True. | ||
] | ||
``` | ||
where N corresponds to `n_harmonic_functions-1`, and f_i is a scalar | ||
denoting the i-th frequency of the harmonic embedding. | ||
If `logspace==True`, the frequencies `[f_1, ..., f_N]` are | ||
powers of 2: | ||
`f_1, ..., f_N = 2**torch.arange(n_harmonic_functions)` | ||
If `logspace==False`, frequencies are linearly spaced between | ||
`1.0` and `2**(n_harmonic_functions-1)`: | ||
`f_1, ..., f_N = torch.linspace( | ||
1.0, 2**(n_harmonic_functions-1), n_harmonic_functions | ||
)` | ||
Note that `x` is also premultiplied by the base frequency `omega_0` | ||
before evaluating the harmonic functions. | ||
Args: | ||
n_harmonic_functions: int, number of harmonic | ||
features | ||
omega_0: float, base frequency | ||
logspace: bool, Whether to space the frequencies in | ||
logspace or linear space | ||
append_input: bool, whether to concat the original | ||
input to the harmonic embedding. If true the | ||
output is of the form (x, embed.sin(), embed.cos() | ||
""" | ||
super().__init__() | ||
|
||
if logspace: | ||
frequencies = 2.0 ** torch.arange( | ||
n_harmonic_functions, | ||
dtype=torch.float32, | ||
) | ||
else: | ||
frequencies = torch.linspace( | ||
1.0, | ||
2.0 ** (n_harmonic_functions - 1), | ||
n_harmonic_functions, | ||
dtype=torch.float32, | ||
) | ||
|
||
self.register_buffer("_frequencies", frequencies * omega_0, persistent=False) | ||
self.append_input = append_input | ||
|
||
def forward(self, x: torch.Tensor) -> torch.Tensor: | ||
""" | ||
Args: | ||
x: tensor of shape [..., dim] | ||
Returns: | ||
embedding: a harmonic embedding of `x` | ||
of shape [..., (n_harmonic_functions * 2 + int(append_input)) * dim] | ||
""" | ||
embed = (x[..., None] * self._frequencies).view(*x.shape[:-1], -1) | ||
embed = torch.cat( | ||
(embed.sin(), embed.cos(), x) | ||
if self.append_input | ||
else (embed.sin(), embed.cos()), | ||
dim=-1, | ||
) | ||
return embed | ||
|
||
@staticmethod | ||
def get_output_dim_static( | ||
input_dims: int, | ||
n_harmonic_functions: int, | ||
append_input: bool, | ||
) -> int: | ||
""" | ||
Utility to help predict the shape of the output of `forward`. | ||
Args: | ||
input_dims: length of the last dimension of the input tensor | ||
n_harmonic_functions: number of embedding frequencies | ||
append_input: whether or not to concat the original | ||
input to the harmonic embedding | ||
Returns: | ||
int: the length of the last dimension of the output tensor | ||
""" | ||
return input_dims * (2 * n_harmonic_functions + int(append_input)) | ||
|
||
def get_output_dim(self, input_dims: int = 3) -> int: | ||
""" | ||
Same as above. The default for input_dims is 3 for 3D applications | ||
which use harmonic embedding for positional encoding, | ||
so the input might be xyz. | ||
""" | ||
return self.get_output_dim_static( | ||
input_dims, len(self._frequencies), self.append_input | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
# Copyright (c) Facebook, Inc. and its affiliates. | ||
# All rights reserved. | ||
# | ||
# This source code is licensed under the BSD-style license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
import unittest | ||
|
||
import torch | ||
from common_testing import TestCaseMixin | ||
from pytorch3d.renderer.implicit import HarmonicEmbedding | ||
|
||
|
||
class TestHarmonicEmbedding(TestCaseMixin, unittest.TestCase): | ||
def setUp(self) -> None: | ||
super().setUp() | ||
torch.manual_seed(1) | ||
|
||
def test_correct_output_dim(self): | ||
embed_fun = HarmonicEmbedding(n_harmonic_functions=2, append_input=False) | ||
# input_dims * (2 * n_harmonic_functions + int(append_input)) | ||
output_dim = 3 * (2 * 2 + int(False)) | ||
self.assertEqual( | ||
output_dim, | ||
embed_fun.get_output_dim_static( | ||
input_dims=3, n_harmonic_functions=2, append_input=False | ||
), | ||
) | ||
self.assertEqual(output_dim, embed_fun.get_output_dim()) | ||
|
||
def test_correct_frequency_range(self): | ||
embed_fun_log = HarmonicEmbedding(n_harmonic_functions=3) | ||
embed_fun_lin = HarmonicEmbedding(n_harmonic_functions=3, logspace=False) | ||
self.assertClose(embed_fun_log._frequencies, torch.FloatTensor((1.0, 2.0, 4.0))) | ||
self.assertClose(embed_fun_lin._frequencies, torch.FloatTensor((1.0, 2.5, 4.0))) | ||
|
||
def test_correct_embed_out(self): | ||
embed_fun = HarmonicEmbedding(n_harmonic_functions=2, append_input=False) | ||
x = torch.randn((1, 5)) | ||
D = 5 * 4 | ||
embed_out = embed_fun(x) | ||
self.assertEqual(embed_out.shape, (1, D)) | ||
# Sum the squares of the respective frequencies | ||
sum_squares = embed_out[0, : D // 2] ** 2 + embed_out[0, D // 2 :] ** 2 | ||
self.assertClose(sum_squares, torch.ones((D // 2))) | ||
embed_fun = HarmonicEmbedding(n_harmonic_functions=2, append_input=True) | ||
embed_out = embed_fun(x) | ||
self.assertClose(embed_out.shape, torch.tensor((1, 5 * 5))) | ||
# Last plane in output is the input | ||
self.assertClose(embed_out[..., -5:], x) |