Skip to content

Commit

Permalink
Adding the option to choose the texture sampling mode in TexturesUV.
Browse files Browse the repository at this point in the history
Summary:
This diff adds the `sample_mode` parameter to `TexturesUV` to control the interpolation mode during texture sampling. It simply gets forwarded to `torch.nn.funcitonal.grid_sample`.

This option was requested in this [GitHub issue](#805).

Reviewed By: patricklabatut

Differential Revision: D32665185

fbshipit-source-id: ac0bc66a018bd4cb20d75fec2d7c11145dd20199
  • Loading branch information
anadodik authored and facebook-github-bot committed Nov 29, 2021
1 parent e4456db commit d9f7095
Showing 1 changed file with 19 additions and 0 deletions.
19 changes: 19 additions & 0 deletions pytorch3d/renderer/mesh/textures.py
Original file line number Diff line number Diff line change
Expand Up @@ -596,6 +596,7 @@ def __init__(
verts_uvs: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor]],
padding_mode: str = "border",
align_corners: bool = True,
sampling_mode: str = "bilinear",
) -> None:
"""
Textures are represented as a per mesh texture map and uv coordinates for each
Expand All @@ -613,6 +614,9 @@ def __init__(
indicate the centers of the edge pixels in the maps.
padding_mode: padding mode for outside grid values
("zeros", "border" or "reflection").
sampling_mode: type of interpolation used to sample the texture.
Corresponds to the mode parameter in PyTorch's
grid_sample ("nearest" or "bilinear").
The align_corners and padding_mode arguments correspond to the arguments
of the `grid_sample` torch function. There is an informative illustration of
Expand Down Expand Up @@ -641,6 +645,7 @@ def __init__(
"""
self.padding_mode = padding_mode
self.align_corners = align_corners
self.sampling_mode = sampling_mode
if isinstance(faces_uvs, (list, tuple)):
for fv in faces_uvs:
if fv.ndim != 2 or fv.shape[-1] != 3:
Expand Down Expand Up @@ -749,6 +754,9 @@ def clone(self) -> "TexturesUV":
self.maps_padded().clone(),
self.faces_uvs_padded().clone(),
self.verts_uvs_padded().clone(),
align_corners=self.align_corners,
padding_mode=self.padding_mode,
sampling_mode=self.sampling_mode,
)
if self._maps_list is not None:
tex._maps_list = [m.clone() for m in self._maps_list]
Expand All @@ -770,6 +778,9 @@ def detach(self) -> "TexturesUV":
self.maps_padded().detach(),
self.faces_uvs_padded().detach(),
self.verts_uvs_padded().detach(),
align_corners=self.align_corners,
padding_mode=self.padding_mode,
sampling_mode=self.sampling_mode,
)
if self._maps_list is not None:
tex._maps_list = [m.detach() for m in self._maps_list]
Expand Down Expand Up @@ -801,6 +812,7 @@ def __getitem__(self, index) -> "TexturesUV":
maps=maps,
padding_mode=self.padding_mode,
align_corners=self.align_corners,
sampling_mode=self.sampling_mode,
)
elif all(torch.is_tensor(f) for f in [faces_uvs, verts_uvs, maps]):
new_tex = self.__class__(
Expand All @@ -809,6 +821,7 @@ def __getitem__(self, index) -> "TexturesUV":
maps=[maps],
padding_mode=self.padding_mode,
align_corners=self.align_corners,
sampling_mode=self.sampling_mode,
)
else:
raise ValueError("Not all values are provided in the correct format")
Expand Down Expand Up @@ -889,6 +902,7 @@ def extend(self, N: int) -> "TexturesUV":
verts_uvs=new_props["verts_uvs_padded"],
padding_mode=self.padding_mode,
align_corners=self.align_corners,
sampling_mode=self.sampling_mode,
)

new_tex._num_faces_per_mesh = new_props["_num_faces_per_mesh"]
Expand Down Expand Up @@ -966,6 +980,7 @@ def sample_textures(self, fragments, **kwargs) -> torch.Tensor:
texels = F.grid_sample(
texture_maps,
pixel_uvs,
mode=self.sampling_mode,
align_corners=self.align_corners,
padding_mode=self.padding_mode,
)
Expand Down Expand Up @@ -1003,6 +1018,7 @@ def faces_verts_textures_packed(self) -> torch.Tensor:
textures = F.grid_sample(
texture_maps,
faces_verts_uvs,
mode=self.sampling_mode,
align_corners=self.align_corners,
padding_mode=self.padding_mode,
) # NxCxmax(Fi)x3
Expand Down Expand Up @@ -1060,6 +1076,7 @@ def join_batch(self, textures: List["TexturesUV"]) -> "TexturesUV":
faces_uvs=faces_uvs_list,
padding_mode=self.padding_mode,
align_corners=self.align_corners,
sampling_mode=self.sampling_mode,
)
new_tex._num_faces_per_mesh = num_faces_per_mesh
return new_tex
Expand Down Expand Up @@ -1227,6 +1244,7 @@ def join_scene(self) -> "TexturesUV":
faces_uvs=[torch.cat(faces_uvs_merged)],
align_corners=self.align_corners,
padding_mode=self.padding_mode,
sampling_mode=self.sampling_mode,
)

def centers_for_image(self, index: int) -> torch.Tensor:
Expand Down Expand Up @@ -1259,6 +1277,7 @@ def centers_for_image(self, index: int) -> torch.Tensor:
torch.flip(coords.to(texture_image), [2]),
# Convert from [0, 1] -> [-1, 1] range expected by grid sample
verts_uvs[:, None] * 2.0 - 1,
mode=self.sampling_mode,
align_corners=self.align_corners,
padding_mode=self.padding_mode,
).cpu()
Expand Down

0 comments on commit d9f7095

Please sign in to comment.