Skip to content

Commit

Permalink
Deduplicate texture maps when joining
Browse files Browse the repository at this point in the history
Summary:
If you join several meshes which have TexturesUV textures using join_meshes_as_scene then we amalgamate all the texture images in to a single one. This now checks if some of the images are equal (i.e. the tensors are the same tensor, in the `is` sense; they have the same `id` in Python) and only uses one copy if they are.

I have an example of a massive scene made of several textured meshes with some shared, where this makes the difference between fitting the data on the GPU and not.

Reviewed By: theschnitz

Differential Revision: D25982364

fbshipit-source-id: a8228805f38475c796302e27328a340d9b56c8ef
  • Loading branch information
bottler authored and facebook-github-bot committed May 26, 2021
1 parent cd5af25 commit e12a081
Show file tree
Hide file tree
Showing 4 changed files with 155 additions and 42 deletions.
43 changes: 22 additions & 21 deletions pytorch3d/renderer/mesh/textures.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from pytorch3d.structures.utils import list_to_packed, list_to_padded, padded_to_list
from torch.nn.functional import interpolate

from .utils import pack_rectangles
from .utils import PackedRectangle, Rectangle, pack_unique_rectangles


# This file contains classes and helper functions for texturing.
Expand Down Expand Up @@ -1028,14 +1028,13 @@ def join_batch(self, textures: List["TexturesUV"]) -> "TexturesUV":
maps_list = []
faces_uvs_list += self.faces_uvs_list()
verts_uvs_list += self.verts_uvs_list()
maps_list += list(self.maps_padded().unbind(0))
maps_list += self.maps_list()
num_faces_per_mesh = self._num_faces_per_mesh
for tex in textures:
verts_uvs_list += tex.verts_uvs_list()
faces_uvs_list += tex.faces_uvs_list()
num_faces_per_mesh += tex._num_faces_per_mesh
tex_map_list = list(tex.maps_padded().unbind(0))
maps_list += tex_map_list
maps_list += tex.maps_list()

new_tex = self.__class__(
maps=maps_list,
Expand All @@ -1048,10 +1047,7 @@ def join_batch(self, textures: List["TexturesUV"]) -> "TexturesUV":
return new_tex

def _place_map_into_single_map(
self,
single_map: torch.Tensor,
map_: torch.Tensor,
location: Tuple[int, int, bool], # (x,y) and whether flipped
self, single_map: torch.Tensor, map_: torch.Tensor, location: PackedRectangle
) -> None:
"""
Copy map into a larger tensor single_map at the destination specified by location.
Expand All @@ -1064,11 +1060,11 @@ def _place_map_into_single_map(
map_: (H, W, 3) source data
location: where to place map
"""
do_flip = location[2]
do_flip = location.flipped
source = map_.transpose(0, 1) if do_flip else map_
border_width = 0 if self.align_corners else 1
lower_u = location[0] + border_width
lower_v = location[1] + border_width
lower_u = location.x + border_width
lower_v = location.y + border_width
upper_u = lower_u + source.shape[0]
upper_v = lower_v + source.shape[1]
single_map[lower_u:upper_u, lower_v:upper_v] = source
Expand Down Expand Up @@ -1102,28 +1098,33 @@ def join_scene(self) -> "TexturesUV":
If align_corners=False, we need to add an artificial border around
every map.
We use the function `pack_rectangles` to provide a layout for the
single map. _place_map_into_single_map is used to copy the maps
into the single map. The merging of verts_uvs and faces_uvs are
handled locally in this function.
We use the function `pack_unique_rectangles` to provide a layout for
the single map. This means that if self was created with a list of maps,
and to() has not been called, and there were two maps which were exactly
the same tensor object, then they will become the same data in the unified map.
_place_map_into_single_map is used to copy the maps into the single map.
The merging of verts_uvs and faces_uvs is handled locally in this function.
"""
maps = self.maps_list()
heights_and_widths = []
extra_border = 0 if self.align_corners else 2
for map_ in maps:
heights_and_widths.append(
(map_.shape[0] + extra_border, map_.shape[1] + extra_border)
Rectangle(
map_.shape[0] + extra_border, map_.shape[1] + extra_border, id(map_)
)
)
merging_plan = pack_rectangles(heights_and_widths)
merging_plan = pack_unique_rectangles(heights_and_widths)
# pyre-fixme[16]: `Tensor` has no attribute `new_zeros`.
single_map = maps[0].new_zeros((*merging_plan.total_size, 3))
verts_uvs = self.verts_uvs_list()
verts_uvs_merged = []

for map_, loc, uvs in zip(maps, merging_plan.locations, verts_uvs):
new_uvs = uvs.clone()
self._place_map_into_single_map(single_map, map_, loc)
do_flip = loc[2]
if loc.is_first:
self._place_map_into_single_map(single_map, map_, loc)
do_flip = loc.flipped
x_shape = map_.shape[1] if do_flip else map_.shape[0]
y_shape = map_.shape[0] if do_flip else map_.shape[1]

Expand Down Expand Up @@ -1164,9 +1165,9 @@ def join_scene(self) -> "TexturesUV":
denom_y = merging_plan.total_size[1] - one_if_align
scale_y = y_shape - one_if_align
new_uvs[:, 1] *= scale_x / denom_x
new_uvs[:, 1] += (loc[0] + one_if_not_align) / denom_x
new_uvs[:, 1] += (loc.x + one_if_not_align) / denom_x
new_uvs[:, 0] *= scale_y / denom_y
new_uvs[:, 0] += (loc[1] + one_if_not_align) / denom_y
new_uvs[:, 0] += (loc.y + one_if_not_align) / denom_y

verts_uvs_merged.append(new_uvs)

Expand Down
96 changes: 84 additions & 12 deletions pytorch3d/renderer/mesh/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,25 @@ def _interpolate_zbuf(

# ----------- Rectangle Packing -------------------- #


class Rectangle(NamedTuple):
xsize: int
ysize: int
identifier: int


class PackedRectangle(NamedTuple):
x: int
y: int
flipped: bool
is_first: bool


class PackedRectangles(NamedTuple):
total_size: Tuple[int, int]
locations: List[PackedRectangle]


# Note the order of members matters here because it determines the queue order.
# We want to place longer rectangles first.
class _UnplacedRectangle(NamedTuple):
Expand All @@ -74,7 +93,7 @@ class _UnplacedRectangle(NamedTuple):

def _try_place_rectangle(
rect: _UnplacedRectangle,
placed_so_far: List[Tuple[int, int, bool]],
placed_so_far: List[PackedRectangle],
occupied: List[Tuple[int, int]],
) -> bool:
"""
Expand Down Expand Up @@ -156,10 +175,11 @@ def _try_place_rectangle(
current_start_idx = idx
if currently_packed >= needed_height:
current_max_width = max(interval[0], current_max_width)
placed_so_far[rect.ind] = (
placed_so_far[rect.ind] = PackedRectangle(
current_max_width,
occupied[current_start_idx - 1][1],
rect.flipped,
True,
)
new_occupied = (
current_max_width + rect.size[0],
Expand All @@ -182,11 +202,6 @@ def _try_place_rectangle(
return False


class PackedRectangles(NamedTuple):
total_size: Tuple[int, int]
locations: List[Tuple[int, int, bool]] # (x,y) and whether flipped


def pack_rectangles(sizes: List[Tuple[int, int]]) -> PackedRectangles:
"""
Naive rectangle packing in to a large rectangle. Flipping (i.e. rotating
Expand All @@ -200,7 +215,9 @@ def pack_rectangles(sizes: List[Tuple[int, int]]) -> PackedRectangles:
Returns:
total_size: size of total large rectangle
rectangles: location for each of the input rectangles
rectangles: location for each of the input rectangles.
This includes whether they are flipped.
The is_first field is always True.
"""

if len(sizes) < 2:
Expand All @@ -213,14 +230,14 @@ def pack_rectangles(sizes: List[Tuple[int, int]]) -> PackedRectangles:
else:
queue.append(_UnplacedRectangle((size[0], size[1]), i, False))
queue.sort()
placed_so_far = [(-1, -1, False)] * len(sizes)
placed_so_far = [PackedRectangle(-1, -1, False, False)] * len(sizes)

biggest = queue.pop()
total_width, current_height = biggest.size
placed_so_far[biggest.ind] = (0, 0, biggest.flipped)
placed_so_far[biggest.ind] = PackedRectangle(0, 0, biggest.flipped, True)

second = queue.pop()
placed_so_far[second.ind] = (0, current_height, second.flipped)
placed_so_far[second.ind] = PackedRectangle(0, current_height, second.flipped, True)
current_height += second.size[1]
occupied = [biggest.size, (second.size[0], current_height)]

Expand All @@ -236,8 +253,63 @@ def pack_rectangles(sizes: List[Tuple[int, int]]) -> PackedRectangles:

# rect wasn't placed in the current bounding box,
# so we add extra space to fit it in.
placed_so_far[rect.ind] = (0, current_height, rect.flipped)
placed_so_far[rect.ind] = PackedRectangle(0, current_height, rect.flipped, True)
current_height += rect.size[1]
occupied.append((rect.size[0], current_height))

return PackedRectangles((total_width, current_height), placed_so_far)


def pack_unique_rectangles(rectangles: List[Rectangle]) -> PackedRectangles:
"""
Naive rectangle packing in to a large rectangle. Flipping (i.e. rotating
a rectangle by 90 degrees) is allowed. Inputs are deduplicated by their
identifier.
This is a wrapper around pack_rectangles, where inputs come with an
identifier. In particular, it calls pack_rectangles for the deduplicated inputs,
then returns the values for all the inputs. The output for all rectangles with
the same identifier will be the same, except that only the first one will have
the is_first field True.
This is used to join several uv maps into a single scene, see
TexturesUV.join_scene.
Args:
rectangles: List of sizes of rectangles to pack
Returns:
total_size: size of total large rectangle
rectangles: location for each of the input rectangles.
This includes whether they are flipped.
The is_first field is true for the first rectangle
with each identifier.
"""

if len(rectangles) < 2:
raise ValueError("Cannot pack less than two boxes")

input_map = {}
input_indices: List[Tuple[int, bool]] = []
unique_input_sizes: List[Tuple[int, int]] = []
for rectangle in rectangles:
if rectangle.identifier not in input_map:
unique_index = len(unique_input_sizes)
unique_input_sizes.append((rectangle.xsize, rectangle.ysize))
input_map[rectangle.identifier] = unique_index
input_indices.append((unique_index, True))
else:
unique_index = input_map[rectangle.identifier]
input_indices.append((unique_index, False))

if len(unique_input_sizes) == 1:
first = [PackedRectangle(0, 0, False, True)]
rest = (len(rectangles) - 1) * [PackedRectangle(0, 0, False, False)]
return PackedRectangles(unique_input_sizes[0], first + rest)

total_size, unique_locations = pack_rectangles(unique_input_sizes)
full_locations = []
for input_index, first in input_indices:
full_locations.append(unique_locations[input_index]._replace(is_first=first))

return PackedRectangles(total_size, full_locations)
9 changes: 8 additions & 1 deletion tests/test_render_meshes.py
Original file line number Diff line number Diff line change
Expand Up @@ -652,6 +652,9 @@ def test_join_uvs(self):
verts_shifted2 = verts.clone()
verts_shifted2 *= 0.5
verts_shifted2[:, 1] -= 7
verts_shifted3 = verts.clone()
verts_shifted3 *= 0.5
verts_shifted3[:, 1] -= 700

[faces] = plain_torus.faces_list()
nocolor = torch.zeros((100, 100), device=device)
Expand Down Expand Up @@ -697,7 +700,11 @@ def test_join_uvs(self):
mesh1 = Meshes(verts=[verts], faces=[faces], textures=textures1)
mesh2 = Meshes(verts=[verts_shifted1], faces=[faces], textures=textures2)
mesh3 = Meshes(verts=[verts_shifted2], faces=[faces], textures=textures3)
mesh = join_meshes_as_scene([mesh1, mesh2, mesh3])
# mesh4 is like mesh1 but outside the field of view. It is here to test
# that having another texture with the same map doesn't produce
# two copies in the joined map.
mesh4 = Meshes(verts=[verts_shifted3], faces=[faces], textures=textures1)
mesh = join_meshes_as_scene([mesh1, mesh2, mesh3, mesh4])

output = renderer(mesh)[0, ..., :3].cpu()
output1 = renderer(mesh1)[0, ..., :3].cpu()
Expand Down
49 changes: 41 additions & 8 deletions tests/test_texturing.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,11 @@
TexturesUV,
TexturesVertex,
_list_to_padded_wrapper,
)
from pytorch3d.renderer.mesh.utils import (
Rectangle,
pack_rectangles,
pack_unique_rectangles,
)
from pytorch3d.structures import Meshes, list_to_packed, packed_to_list
from test_meshes import init_mesh
Expand Down Expand Up @@ -873,21 +877,24 @@ def wrap_pack(self, sizes):
mask = torch.zeros(total, dtype=torch.bool)
seen_x_bound = False
seen_y_bound = False
for (in_x, in_y), loc in zip(sizes, res.locations):
self.assertGreaterEqual(loc[0], 0)
self.assertGreaterEqual(loc[1], 0)
placed_x, placed_y = (in_y, in_x) if loc[2] else (in_x, in_y)
upper_x = placed_x + loc[0]
upper_y = placed_y + loc[1]
for (in_x, in_y), (out_x, out_y, flipped, is_first) in zip(
sizes, res.locations
):
self.assertTrue(is_first)
self.assertGreaterEqual(out_x, 0)
self.assertGreaterEqual(out_y, 0)
placed_x, placed_y = (in_y, in_x) if flipped else (in_x, in_y)
upper_x = placed_x + out_x
upper_y = placed_y + out_y
self.assertGreaterEqual(total[0], upper_x)
if total[0] == upper_x:
seen_x_bound = True
self.assertGreaterEqual(total[1], upper_y)
if total[1] == upper_y:
seen_y_bound = True
already_taken = torch.sum(mask[loc[0] : upper_x, loc[1] : upper_y])
already_taken = torch.sum(mask[out_x:upper_x, out_y:upper_y])
self.assertEqual(already_taken, 0)
mask[loc[0] : upper_x, loc[1] : upper_y] = 1
mask[out_x:upper_x, out_y:upper_y] = 1
self.assertTrue(seen_x_bound)
self.assertTrue(seen_y_bound)

Expand Down Expand Up @@ -930,3 +937,29 @@ def test_random(self):
for j in range(vals.shape[0]):
sizes.append((int(vals[j, 0]), int(vals[j, 1])))
self.wrap_pack(sizes)

def test_all_identical(self):
sizes = [Rectangle(xsize=61, ysize=82, identifier=1729)] * 3
total_size, locations = pack_unique_rectangles(sizes)
self.assertEqual(total_size, (61, 82))
self.assertEqual(len(locations), 3)
for i, (x, y, is_flipped, is_first) in enumerate(locations):
self.assertEqual(x, 0)
self.assertEqual(y, 0)
self.assertFalse(is_flipped)
self.assertEqual(is_first, i == 0)

def test_one_different_id(self):
sizes = [Rectangle(xsize=61, ysize=82, identifier=220)] * 3
sizes.extend([Rectangle(xsize=61, ysize=82, identifier=284)] * 3)
total_size, locations = pack_unique_rectangles(sizes)
self.assertEqual(total_size, (82, 122))
self.assertEqual(len(locations), 6)
for i, (x, y, is_flipped, is_first) in enumerate(locations):
self.assertTrue(is_flipped)
self.assertEqual(is_first, i % 3 == 0)
self.assertEqual(x, 0)
if i < 3:
self.assertEqual(y, 61)
else:
self.assertEqual(y, 0)

0 comments on commit e12a081

Please sign in to comment.