Skip to content

Commit

Permalink
Move Harmonic embedding to core pytorch3d
Browse files Browse the repository at this point in the history
Summary:
Moved the `HarmonicEmbedding` module into core PyTorch3D.
The next diff will update the NeRF project accordingly.

Reviewed By: bottler

Differential Revision: D32833808

fbshipit-source-id: 0a12ccd1627c0ce024463c796544c91eb8d4d122
  • Loading branch information
nikhilaravi authored and facebook-github-bot committed Dec 21, 2021
1 parent d67662d commit f9a26a2
Show file tree
Hide file tree
Showing 4 changed files with 179 additions and 1 deletion.
1 change: 1 addition & 0 deletions pytorch3d/renderer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
VolumeSampler,
ray_bundle_to_ray_points,
ray_bundle_variables_to_ray_points,
HarmonicEmbedding,
)
from .lighting import AmbientLights, DirectionalLights, PointLights, diffuse, specular
from .materials import Materials
Expand Down
2 changes: 1 addition & 1 deletion pytorch3d/renderer/implicit/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from .harmonic_embedding import HarmonicEmbedding
from .raymarching import AbsorptionOnlyRaymarcher, EmissionAbsorptionRaymarcher
from .raysampling import GridRaysampler, MonteCarloRaysampler, NDCGridRaysampler
from .renderer import ImplicitRenderer, VolumeRenderer, VolumeSampler
Expand All @@ -13,5 +14,4 @@
ray_bundle_variables_to_ray_points,
)


__all__ = [k for k in globals().keys() if not k.startswith("_")]
127 changes: 127 additions & 0 deletions pytorch3d/renderer/implicit/harmonic_embedding.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch


class HarmonicEmbedding(torch.nn.Module):
    """
    Layer that maps each input coordinate to a bank of sine and cosine
    features evaluated at a fixed set of frequencies.
    """

    def __init__(
        self,
        n_harmonic_functions: int = 6,
        omega_0: float = 1.0,
        logspace: bool = True,
        append_input: bool = True,
    ) -> None:
        """
        Given an input tensor `x` of shape [minibatch, ... , dim],
        the harmonic embedding layer converts each feature
        (i.e. vector along the last dimension) in `x`
        into a series of harmonic features `embedding`,
        where for each i in range(dim) the following are present
        in embedding[...]:
            ```
            [
                sin(f_1*x[..., i]),
                sin(f_2*x[..., i]),
                ...
                sin(f_N * x[..., i]),
                cos(f_1*x[..., i]),
                cos(f_2*x[..., i]),
                ...
                cos(f_N * x[..., i]),
                x[..., i],              # only present if append_input is True.
            ]
            ```
        where N corresponds to `n_harmonic_functions`, and f_i is a scalar
        denoting the i-th frequency of the harmonic embedding.

        If `logspace==True`, the frequencies `[f_1, ..., f_N]` are
        powers of 2:
            `f_1, ..., f_N = 2**torch.arange(n_harmonic_functions)`

        If `logspace==False`, frequencies are linearly spaced between
        `1.0` and `2**(n_harmonic_functions-1)`:
            `f_1, ..., f_N = torch.linspace(
                1.0, 2**(n_harmonic_functions-1), n_harmonic_functions
            )`

        Note that the frequencies are scaled by the base frequency `omega_0`,
        which is equivalent to premultiplying `x` by `omega_0` before
        evaluating the harmonic functions.

        Args:
            n_harmonic_functions: int, number of harmonic
                features
            omega_0: float, base frequency
            logspace: bool, Whether to space the frequencies in
                logspace or linear space
            append_input: bool, whether to concat the original
                input to the harmonic embedding. If true the
                output is of the form (embed.sin(), embed.cos(), x)
        """
        super().__init__()

        if logspace:
            frequencies = 2.0 ** torch.arange(
                n_harmonic_functions,
                dtype=torch.float32,
            )
        else:
            frequencies = torch.linspace(
                1.0,
                2.0 ** (n_harmonic_functions - 1),
                n_harmonic_functions,
                dtype=torch.float32,
            )

        # Registered as a non-persistent buffer: it follows the module across
        # devices but is left out of the state_dict, since it is fully
        # determined by the constructor arguments.
        self.register_buffer("_frequencies", frequencies * omega_0, persistent=False)
        self.append_input = append_input

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: tensor of shape [..., dim]
        Returns:
            embedding: a harmonic embedding of `x`
                of shape [..., (n_harmonic_functions * 2 + int(append_input)) * dim].
                Along the last dimension the layout is: all sine features
                (grouped per input coordinate, frequencies varying fastest),
                then all cosine features in the same order, then `x` itself
                if `append_input` is True.
        """
        # [..., dim, n_harmonic_functions] -> [..., dim * n_harmonic_functions].
        # `reshape` rather than `view`: the broadcasted product may inherit a
        # non-contiguous layout from `x` (e.g. a transposed input), which
        # `view` cannot flatten.
        embed = (x[..., None] * self._frequencies).reshape(*x.shape[:-1], -1)
        embed = torch.cat(
            (embed.sin(), embed.cos(), x)
            if self.append_input
            else (embed.sin(), embed.cos()),
            dim=-1,
        )
        return embed

    @staticmethod
    def get_output_dim_static(
        input_dims: int,
        n_harmonic_functions: int,
        append_input: bool,
    ) -> int:
        """
        Utility to help predict the shape of the output of `forward`.

        Args:
            input_dims: length of the last dimension of the input tensor
            n_harmonic_functions: number of embedding frequencies
            append_input: whether or not to concat the original
                input to the harmonic embedding
        Returns:
            int: the length of the last dimension of the output tensor
        """
        return input_dims * (2 * n_harmonic_functions + int(append_input))

    def get_output_dim(self, input_dims: int = 3) -> int:
        """
        Same as `get_output_dim_static`, but using this instance's number of
        frequencies and `append_input` setting. The default for input_dims
        is 3 for 3D applications which use harmonic embedding for positional
        encoding, so the input might be xyz.

        Args:
            input_dims: length of the last dimension of the input tensor
        Returns:
            int: the length of the last dimension of the output tensor
        """
        return self.get_output_dim_static(
            input_dims, len(self._frequencies), self.append_input
        )
50 changes: 50 additions & 0 deletions tests/test_harmonic_embedding.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch
from common_testing import TestCaseMixin
from pytorch3d.renderer.implicit import HarmonicEmbedding


class TestHarmonicEmbedding(TestCaseMixin, unittest.TestCase):
    """Unit tests for the output size, frequencies and values of `HarmonicEmbedding`."""

    def setUp(self) -> None:
        super().setUp()
        # Fixed seed so the random input in test_correct_embed_out is reproducible.
        torch.manual_seed(1)

    def test_correct_output_dim(self):
        """Both output-dim helpers agree with the closed-form expression."""
        embed_fun = HarmonicEmbedding(n_harmonic_functions=2, append_input=False)
        # input_dims * (2 * n_harmonic_functions + int(append_input))
        output_dim = 3 * (2 * 2 + int(False))
        self.assertEqual(
            output_dim,
            embed_fun.get_output_dim_static(
                input_dims=3, n_harmonic_functions=2, append_input=False
            ),
        )
        # get_output_dim defaults to input_dims=3.
        self.assertEqual(output_dim, embed_fun.get_output_dim())

    def test_correct_frequency_range(self):
        """Log-spaced and linearly spaced frequencies cover [1, 2**(n-1)]."""
        embed_fun_log = HarmonicEmbedding(n_harmonic_functions=3)
        embed_fun_lin = HarmonicEmbedding(n_harmonic_functions=3, logspace=False)
        self.assertClose(embed_fun_log._frequencies, torch.FloatTensor((1.0, 2.0, 4.0)))
        self.assertClose(embed_fun_lin._frequencies, torch.FloatTensor((1.0, 2.5, 4.0)))

    def test_correct_embed_out(self):
        """Check output shape, the sin^2 + cos^2 identity and the appended input."""
        input_dims = 5
        n_harmonic_functions = 2
        embed_fun = HarmonicEmbedding(
            n_harmonic_functions=n_harmonic_functions, append_input=False
        )
        x = torch.randn((1, input_dims))
        # sin and cos halves, each of width input_dims * n_harmonic_functions.
        D = input_dims * 2 * n_harmonic_functions
        embed_out = embed_fun(x)
        self.assertEqual(embed_out.shape, (1, D))
        # The first half holds the sines and the second half the cosines of
        # the same arguments, so their squares must sum to one.
        sum_squares = embed_out[0, : D // 2] ** 2 + embed_out[0, D // 2 :] ** 2
        self.assertClose(sum_squares, torch.ones((D // 2)))
        embed_fun = HarmonicEmbedding(
            n_harmonic_functions=n_harmonic_functions, append_input=True
        )
        embed_out = embed_fun(x)
        # Appending the input adds input_dims extra channels.
        self.assertEqual(embed_out.shape, (1, D + input_dims))
        # Last plane in output is the input
        self.assertClose(embed_out[..., -input_dims:], x)

0 comments on commit f9a26a2

Please sign in to comment.