Multimodal collater with interleaved image, cross-attention mask padding (#1156)

Rafi Ayub <[email protected]>
RdoubleA authored Sep 11, 2024
1 parent b9f2caf commit 377abc0
Showing 4 changed files with 303 additions and 2 deletions.
1 change: 1 addition & 0 deletions docs/source/api_ref_data.rst
@@ -73,6 +73,7 @@ Collaters used to collect samples into batches and handle any padding.
    :nosignatures:

    padded_collate
    padded_collate_tiled_images_and_mask
    padded_collate_sft
    padded_collate_dpo
    left_pad_sequence
107 changes: 107 additions & 0 deletions tests/torchtune/data/test_collate.py
@@ -17,6 +17,7 @@
    padded_collate_dpo,
    padded_collate_packed,
    padded_collate_sft,
    padded_collate_tiled_images_and_mask,
)
from torchtune.modules.attention_utils import _SUPPORTS_FLEX_ATTENTION

@@ -52,6 +53,112 @@ def test_batch_pad_sequence(self):
            padded_label, torch.tensor([10, ignore_idx, ignore_idx])
        )


class TestPaddedCollateTiledImagesAndMask:
    @pytest.fixture
    def batch(self):
        return [
            {
                "tokens": [1, 2, 1, 3],
                "labels": [4, 5, 6, 7],
                "encoder_input": {
                    "images": [torch.ones(2, 1, 1, 1), torch.ones(3, 1, 1, 1)],
                    "aspect_ratio": [torch.tensor([1, 2]), torch.tensor([1, 3])],
                },
                "encoder_mask": [torch.ones(4, 5 * 2), torch.ones(4, 5 * 3)],
            },
            {
                "tokens": [1, 4],
                "labels": [8, 9],
                "encoder_input": {
                    "images": [torch.ones(4, 1, 1, 1)],
                    "aspect_ratio": [torch.tensor([2, 2])],
                },
                "encoder_mask": [torch.ones(2, 5 * 4)],
            },
        ]
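
    # With tokens_per_tile = 5 (see the masks above), this fixture yields
    # max_num_tiles = 4 and max_num_images = 2, so the collated encoder_mask
    # should have shape (bsz=2, max_seq_len=4, 5 * 4 * 2 = 40)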

    def test_right_pad_sequence(self, batch):
        actual = padded_collate_tiled_images_and_mask(
            batch=batch, padding_idx=0, ignore_idx=-100, pad_direction="right"
        )

        # Each per-image mask is padded up to max_image_seq_len = 4 tiles * 5 tokens = 20
        mask_1 = torch.concat([torch.ones(4, 5 * 2), torch.zeros(4, 10)], dim=1)
        mask_2 = torch.concat([torch.ones(4, 5 * 3), torch.zeros(4, 5)], dim=1)
        # The second sample's mask needs no tile padding (already 4 tiles), but its
        # text dim is padded from 2 to the batch max of 4
        mask_3 = torch.concat([torch.ones(2, 5 * 4), torch.zeros(2, 20)], dim=0)
        sample_1 = torch.stack([mask_1, mask_2])
        # An all-zero mask stands in for the second sample's padding image
        sample_2 = torch.stack([mask_3, torch.zeros(4, 20)])
        expected_mask = torch.stack([sample_1, sample_2]).view(2, 4, -1)

        expected = {
            "tokens": torch.tensor([[1, 2, 1, 3], [1, 4, 0, 0]]),
            "labels": torch.tensor([[4, 5, 6, 7], [8, 9, -100, -100]]),
            "encoder_input": {
                "images": torch.tensor(
                    [
                        [
                            [[[[1.0]]], [[[1.0]]], [[[0.0]]], [[[0.0]]]],
                            [[[[1.0]]], [[[1.0]]], [[[1.0]]], [[[0.0]]]],
                        ],
                        [
                            [[[[1.0]]], [[[1.0]]], [[[1.0]]], [[[1.0]]]],
                            [[[[0.0]]], [[[0.0]]], [[[0.0]]], [[[0.0]]]],
                        ],
                    ]
                ),
                "aspect_ratio": torch.tensor([[[1, 2], [1, 3]], [[2, 2], [1, 1]]]),
            },
            "encoder_mask": expected_mask,
        }

        for k in expected:
            if isinstance(expected[k], dict):
                for k1 in expected[k]:
                    torch.testing.assert_close(actual[k][k1], expected[k][k1])
            else:
                torch.testing.assert_close(actual[k], expected[k])

    def test_left_pad_sequence(self, batch):
        actual = padded_collate_tiled_images_and_mask(
            batch=batch, padding_idx=0, ignore_idx=-100, pad_direction="left"
        )

        # The expected cross-attention mask is identical to the right-pad case:
        # mask padding is applied on the right regardless of pad_direction
        mask_1 = torch.concat([torch.ones(4, 5 * 2), torch.zeros(4, 10)], dim=1)
        mask_2 = torch.concat([torch.ones(4, 5 * 3), torch.zeros(4, 5)], dim=1)
        mask_3 = torch.concat([torch.ones(2, 5 * 4), torch.zeros(2, 20)], dim=0)
        sample_1 = torch.stack([mask_1, mask_2])
        sample_2 = torch.stack([mask_3, torch.zeros(4, 20)])
        expected_mask = torch.stack([sample_1, sample_2]).view(2, 4, -1)

        # No "labels" key here: left padding is the inference path, which skips labels
        expected = {
            "tokens": torch.tensor([[1, 2, 1, 3], [0, 0, 1, 4]]),
            "encoder_input": {
                "images": torch.tensor(
                    [
                        [
                            [[[[1.0]]], [[[1.0]]], [[[0.0]]], [[[0.0]]]],
                            [[[[1.0]]], [[[1.0]]], [[[1.0]]], [[[0.0]]]],
                        ],
                        [
                            [[[[1.0]]], [[[1.0]]], [[[1.0]]], [[[1.0]]]],
                            [[[[0.0]]], [[[0.0]]], [[[0.0]]], [[[0.0]]]],
                        ],
                    ]
                ),
                "aspect_ratio": torch.tensor([[[1, 2], [1, 3]], [[2, 2], [1, 1]]]),
            },
            "encoder_mask": expected_mask,
        }

        for k in expected:
            if isinstance(expected[k], dict):
                for k1 in expected[k]:
                    torch.testing.assert_close(actual[k][k1], expected[k][k1])
            else:
                torch.testing.assert_close(actual[k], expected[k])


class TestPaddedCollatePacked:
    @mock.patch("torchtune.modules.attention_utils._SUPPORTS_FLEX_ATTENTION", False)
    def test_padded_collate_packed_sdpa(self):
        token_pairs = [
2 changes: 2 additions & 0 deletions torchtune/data/__init__.py
@@ -11,6 +11,7 @@
    padded_collate_dpo,
    padded_collate_packed,
    padded_collate_sft,
    padded_collate_tiled_images_and_mask,
)
from torchtune.data._common import CROSS_ENTROPY_IGNORE_IDX
from torchtune.data._converters import get_openai_messages, get_sharegpt_messages
@@ -59,6 +60,7 @@
    "padded_collate_dpo",
    "left_pad_sequence",
    "padded_collate",
    "padded_collate_tiled_images_and_mask",
    "padded_collate_packed",
    "load_image",
]
195 changes: 193 additions & 2 deletions torchtune/data/_collate.py
@@ -3,7 +3,7 @@
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
-from typing import Dict, List, Tuple, Union
+from typing import Any, Dict, List, Tuple, Union

import torch
import torch.nn.functional as F
@@ -85,7 +85,7 @@ def padded_collate(
        torch.Tensor: The padded tensor of input ids with shape [batch_size, max_seq_len].

    Raises:
-        ValueError: if ``pad_direction`` is not one of "left" or "right.
+        ValueError: if ``pad_direction`` is not one of "left" or "right".
        ValueError: if ``keys_to_pad`` is empty, or is not a list, or is not a subset of keys in the batch.
        ValueError: if ``padding_idx`` is provided as a dictionary, but the keys are not identical to
            ``keys_to_pad``.
@@ -216,6 +216,197 @@ def padded_collate_sft(
    return {"tokens": input_ids.long(), "labels": labels.long()}


# TODO: Generalize this to support any type of encoder input; right now this assumes
# a specific encoder_input signature
def padded_collate_tiled_images_and_mask(
    batch: List[Dict[str, Any]],
    padding_idx: int = 0,
    ignore_idx: int = CROSS_ENTROPY_IGNORE_IDX,
    pad_direction: str = "right",
) -> Dict[str, torch.Tensor]:
"""Pad a batch of text sequences, tiled image tensors, aspect ratios,
and cross attention masks. This can be used for both training and inference.
``batch`` is expected to be a list of sample dicts containing the following:
- "tokens": List[int] of length text_seq_len, varies across samples
- "labels": List[int] of length text_seq_len, varies across samples
- "encoder_input": Dict[str, List[torch.Tensor]]
- "images": List[torch.Tensor], each with shape (n_tiles, c, h, w)
- "aspect_ratio": List[torch.Tensor], each with shape (2, ) to indicate h_ratio, w_ratio
- "encoder_mask": List[Tensor], each with shape (text_seq_len, image_seq_len)
where c = channel dim, h = height dim, w = weight dim. For each element in the batch,
len(images) == len(encoder_mask) == len(aspect_ratio).
This collater does the following:
(1) Pad text sequence and encoder mask to the longest sequence length in the batch
(2) Pad image tensors in the tile dimension with zeros to the largest number
of tiles in the batch
(3) Add empty images of zeros to samples up to max number of images in the batch
(4) Pad aspect ratios with (1,1) for all added padding images
Args:
batch (List[Dict[str, Any]]): A list of sample dicts containing tokens,
labels, images, encoder_mask, and aspect_ratio.
padding_idx (int): Padding index for input token ids. Defaults to 0.
ignore_idx (int): Padding index for labels. Defaults to -100.
pad_direction (str): whether to pad entries from the left, or right. If ``pad_direction="right"``, we use
:func:`torch.nn.utils.rnn.pad_sequence`, otherwise if ``pad_direction="left"``,
we use :func:`torchtune.data.left_pad_sequence`. For training, we typically want to pad from the right.
For inference, we typically want to pad from the left. Defaults to "right".
Returns:
Dict[str, Tensor]: Collated tokens, labels, images, encoder_mask, aspect_ratio tensors.
- tokens: Tensor of shape (bsz, max_seq_len)
- labels: Tensor of shape (bsz, max_seq_len)
- images: Tensor of shape (bsz, max_num_images, max_num_tiles, c, h, w)
- encoder_mask: Tensor of shape (bsz, max_seq_len, tokens_per_tile * max_num_tiles * max_num_images)
- aspect_ratio: Tensor of shape (bsz, max_num_images, 2)
Raises:
ValueError: if ``pad_direction`` is not one of "left" or "right".
Example:
>>> image_id = 1
>>> tokens_per_tile = 5
>>> c, h, w = 1, 1, 1
>>> batch = [
... {
... "tokens": [1, 2, 1, 3], "labels": [4, 5, 6, 7],
... "encoder_input": {
... # One image with two tiles, one image with three tiles
... "images": [torch.ones(2, c, h, w), torch.ones(3, c, h, w)],
... "aspect_ratio": [torch.tensor([1, 2]), torch.tensor([1, 3])],
... },
... # Mask is shape (text_seq_len, tokens_per_tile * n_tiles)
... "encoder_mask": [torch.ones(4, 5 * 2), torch.ones(4, 5 * 3)],
... },
... {
... "tokens": [1, 4], "labels": [8, 9],
... "encoder_input": {
... # One image with four tiles
... "images": [torch.ones(4, c, h, w)],
... "aspect_ratio": [torch.tensor([2, 2])],
... },
... # Mask is shape (text_seq_len, tokens_per_tile * n_tiles)
... "encoder_mask": [torch.ones(2, 5 * 4)],
... },
... ]
>>> model_inputs = padded_collate_tiled_images_and_mask(batch=batch)
>>> print(model_inputs["tokens"])
tensor([[1, 2, 1, 3],
[1, 4, 0, 0]])
>>> print(model_inputs["labels"])
tensor([[4, 5, 6, 7],
[8, 9, -100, -100]])
>>> print(model_inputs["encoder_input"]["images"].shape) # (bsz, max_num_images, max_num_tiles, c, h, w)
torch.Size([2, 2, 4, 1, 1, 1])
>>> print(model_inputs["encoder_mask"].shape) # (bsz, max_text_seq_len, tokens_per_tile * max_num_tiles * max_num_images)
torch.Size([2, 4, 40])
>>> print(model_inputs["encoder_input"]["aspect_ratio"].shape) # (bsz, max_num_images, 2)
torch.Size([2, 2, 2])
>>> print(model_inputs["encoder_input"]["images"][0, 0, ...]) # Image with two tiles got padded to four
tensor([[[[1.]]], [[[1.]]], [[[0.]]], [[[0.]]]])
>>> print(model_inputs["encoder_input"]["images"][0, 1, ...]) # Image with three tiles got padded to four
tensor([[[[1.]]], [[[1.]]], [[[1.]]], [[[0.]]]])
>>> print(model_inputs["encoder_input"]["images"][1, 0, ...]) # Image with four tiles did not get padded
tensor([[[[1.]]], [[[1.]]], [[[1.]]], [[[1.]]]])
>>> print(model_inputs["encoder_input"]["images"][1, 1, ...]) # Extra padding image was added to second sample
tensor([[[[0.]]], [[[0.]]], [[[0.]]], [[[0.]]]])
"""
    if pad_direction not in ["left", "right"]:
        raise ValueError(
            f"pad_direction should be one of 'left' or 'right' but found {pad_direction}"
        )

    # Text tokens can be handled independently by existing collaters
    if pad_direction == "right":
        text_only = [
            {"tokens": sample["tokens"], "labels": sample["labels"]} for sample in batch
        ]
        collated_text = padded_collate_sft(text_only, padding_idx, ignore_idx)
    # For inference, we don't need to handle labels
    elif pad_direction == "left":
        collated_text = {
            "tokens": left_pad_sequence(
                [torch.tensor(x["tokens"]) for x in batch],
                batch_first=True,
                padding_value=padding_idx,
            )
        }

    max_seq_len = collated_text["tokens"].shape[-1]
    bsz = len(batch)

    # TODO: Figure out how to make this more efficient or vectorized. Setting
    # max_num_tiles beforehand will save one nested for loop but may incur more
    # memory and compute costs in attention if max_num_tiles > batch_max_num_tiles

    # First loop: get max number of tiles in batch
    max_num_tiles = max(
        image.shape[0]
        for sample in batch
        for image in sample["encoder_input"]["images"]
    )
    # Second loop: pad images and masks to max number of tiles, max text seq len in batch
    batch_images = []
    batch_masks = []
    batch_aspect_ratios = []
    for sample in batch:
        sample_images = []
        sample_masks = []
        for image, mask in zip(
            sample["encoder_input"]["images"], sample["encoder_mask"]
        ):
            # Single image in each sample has shape (n_tiles, c, h, w)
            n_tiles = image.shape[0]
            # Single mask in each sample corresponds to a single image and has shape (text_seq_len, image_seq_len)
            # where image_seq_len = n_tiles * tokens_per_tile
            text_seq_len, image_seq_len = mask.shape
            tokens_per_tile = image_seq_len // n_tiles
            padding_tiles = max_num_tiles - n_tiles
            padding_text = max_seq_len - text_seq_len
            # Image should now have shape (max_num_tiles, c, h, w)
            padded_image = F.pad(image, (0, 0, 0, 0, 0, 0, 0, padding_tiles), value=0)
            # Mask should now have shape (max_seq_len, max_image_seq_len), where
            # max_image_seq_len = max_num_tiles * tokens_per_tile
            padded_mask = F.pad(
                mask, (0, padding_tiles * tokens_per_tile, 0, padding_text), value=0
            )
            sample_images.append(padded_image)
            sample_masks.append(padded_mask)
        # Stack multiple images and masks per sample in num_images dimension
        batch_images.append(torch.stack(sample_images))
        batch_masks.append(torch.stack(sample_masks))
        batch_aspect_ratios.append(torch.stack(sample["encoder_input"]["aspect_ratio"]))
    # Finally, pad images, masks, aspect ratios to max number of images in batch
    # (bsz, max_num_images, max_num_tiles, c, h, w)
    collated_images = pad_sequence(batch_images, batch_first=True, padding_value=0)
    # (bsz, max_num_images, max_seq_len, max_image_seq_len)
    collated_masks = pad_sequence(batch_masks, batch_first=True, padding_value=0)
    # (bsz, max_num_images, 2)
    collated_aspect_ratios = pad_sequence(
        batch_aspect_ratios, batch_first=True, padding_value=1
    )

    # Concatenate masks for multiple images across image_seq_len dimension
    concat_masks = collated_masks.view(bsz, max_seq_len, -1)

    batch_dict = {
        "tokens": collated_text["tokens"],
        "encoder_input": {
            "images": collated_images,
            "aspect_ratio": collated_aspect_ratios,
        },
        "encoder_mask": concat_masks,
    }

    if "labels" in collated_text:
        batch_dict["labels"] = collated_text["labels"]

    return batch_dict


def padded_collate_packed(
    batch: List[PACK_TYPE],
) -> Dict[str, torch.Tensor]:
    ...
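
For context, here is a minimal sketch of how the new collater might be wired into a DataLoader. This is an illustration rather than part of the commit; ``my_multimodal_dataset`` is a hypothetical dataset yielding sample dicts in the format described in the docstring above.

from functools import partial

from torch.utils.data import DataLoader

from torchtune.data import padded_collate_tiled_images_and_mask

# Bind the collater's settings once so the DataLoader can invoke it per batch
collate_fn = partial(
    padded_collate_tiled_images_and_mask,
    padding_idx=0,
    ignore_idx=-100,
    pad_direction="right",  # use "left" when batching for inference
)
loader = DataLoader(my_multimodal_dataset, batch_size=2, collate_fn=collate_fn)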
