Multimodal collater with interleaved image, cross-attention mask padding #1156

Merged · 21 commits · Sep 11, 2024
Changes from 7 commits
1 change: 1 addition & 0 deletions docs/source/api_ref_data.rst
@@ -73,6 +73,7 @@ Collaters used to collect samples into batches and handle any padding.
:nosignatures:

padded_collate
padded_collate_tiled_images_with_cross_attention
padded_collate_sft
padded_collate_dpo
left_pad_sequence
51 changes: 51 additions & 0 deletions tests/torchtune/data/test_collate.py
@@ -13,6 +13,7 @@
padded_collate,
padded_collate_dpo,
padded_collate_sft,
padded_collate_tiled_images_with_cross_attention,
)


@@ -47,6 +48,56 @@ def test_batch_pad_sequence(self):
padded_label, torch.tensor([10, ignore_idx, ignore_idx])
)

def test_padded_collate_tiled_images_with_cross_attention(self):
batch = [
{
"tokens": [1, 2, 1, 3],
"labels": [4, 5, 6, 7],
"images": [torch.ones(2, 1, 1, 1), torch.ones(3, 1, 1, 1)],
"encoder_mask": [torch.ones(4, 5 * 2), torch.ones(4, 5 * 3)],
"aspect_ratio": [torch.tensor([1, 2]), torch.tensor([1, 2])],
},
{
"tokens": [1, 4],
"labels": [8, 9],
"images": [torch.ones(4, 1, 1, 1)],
"encoder_mask": [torch.ones(2, 5 * 4)],
"aspect_ratio": [torch.tensor([1, 2])],
},
]
actual = padded_collate_tiled_images_with_cross_attention(
batch=batch, padding_idx=0, ignore_idx=-100
)

mask_1 = torch.concat([torch.ones(4, 5 * 2), torch.zeros(4, 10)], dim=1)
mask_2 = torch.concat([torch.ones(4, 5 * 3), torch.zeros(4, 5)], dim=1)
mask_3 = torch.concat([torch.ones(2, 5 * 4), torch.zeros(2, 20)], dim=0)
sample_1 = torch.stack([mask_1, mask_2])
sample_2 = torch.stack([mask_3, torch.zeros(4, 20)])
expected_mask = torch.stack([sample_1, sample_2])

expected = {
"tokens": torch.tensor([[1, 2, 1, 3], [1, 4, 0, 0]]),
"labels": torch.tensor([[4, 5, 6, 7], [8, 9, -100, -100]]),
"images": torch.tensor(
[
[
[[[[1.0]]], [[[1.0]]], [[[0.0]]], [[[0.0]]]],
[[[[1.0]]], [[[1.0]]], [[[1.0]]], [[[0.0]]]],
],
[
[[[[1.0]]], [[[1.0]]], [[[1.0]]], [[[1.0]]]],
[[[[0.0]]], [[[0.0]]], [[[0.0]]], [[[0.0]]]],
],
]
),
"encoder_mask": expected_mask,
"aspect_ratio": torch.tensor([[[1, 2], [1, 2]], [[1, 2], [1, 1]]]),
}

for k in expected:
torch.testing.assert_close(actual[k], expected[k])


class TestLeftPadSequence:
def test_left_pad_sequence(self):
2 changes: 2 additions & 0 deletions torchtune/data/__init__.py
@@ -10,6 +10,7 @@
padded_collate,
padded_collate_dpo,
padded_collate_sft,
padded_collate_tiled_images_with_cross_attention,
)
from torchtune.data._common import CROSS_ENTROPY_IGNORE_IDX
from torchtune.data._converters import get_openai_messages, get_sharegpt_messages
@@ -58,4 +59,5 @@
"padded_collate_dpo",
"left_pad_sequence",
"padded_collate",
"padded_collate_tiled_images_with_cross_attention",
]
138 changes: 137 additions & 1 deletion torchtune/data/_collate.py
@@ -3,7 +3,7 @@
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import Dict, List, Tuple, Union
from typing import Any, Dict, List, Tuple, Union

import torch
import torch.nn.functional as F
@@ -214,6 +214,142 @@ def padded_collate_sft(
return {"tokens": input_ids.long(), "labels": labels.long()}


def padded_collate_tiled_images_with_cross_attention(
batch: List[Dict[str, Any]],
padding_idx: int = 0,
ignore_idx: int = CROSS_ENTROPY_IGNORE_IDX,
) -> Dict[str, torch.Tensor]:
"""Pad a batch of text sequences, tiled image tensors, aspect ratios,
and cross-attention masks.

``batch`` is expected to be a list of sample dicts containing the following:
- "tokens": List[int] of length text_seq_len, varies across samples
- "labels": List[int] of length text_seq_len, varies across samples
- "images": List[Tensor], each with shape (n_tiles, c, h, w)
- "encoder_mask": List[Tensor], each with shape (text_seq_len, image_seq_len)
- "aspect_ratio": List[Tensor], each with shape (h_ratio, w_ratio)

This collater does the following:
(1) Pad text sequences and encoder masks to the longest sequence length in the batch
(2) Pad image tensors in the tile dimension with zeros to the largest number
of tiles in the batch
(3) Add all-zero padding images to samples, up to the max number of images in the batch
(4) Pad aspect ratios with (1,1) for all added padding images

Args:
batch (List[Dict[str, Any]]): A list of sample dicts containing tokens,
labels, images, encoder_mask, and aspect_ratio.
padding_idx (int): Padding index for input token ids. Defaults to 0.
ignore_idx (int): Padding index for labels. Defaults to -100.

Returns:
Dict[str, Tensor]: Collated tokens, labels, images, encoder_mask, aspect_ratio tensors.

Example:
>>> image_id = 1
>>> tokens_per_tile = 5
>>> c, h, w = 1, 1, 1
>>> batch = [
... {
... "tokens": [1, 2, 1, 3], "labels": [4, 5, 6, 7],
... # One image with two tiles, one image with three tiles
... "images": [torch.ones(2, c, h, w), torch.ones(3, c, h, w)],
... "encoder_mask": [torch.ones(4, 5 * 2), torch.ones(4, 5 * 3)],
... "aspect_ratio": [torch.tensor([1, 2]), torch.tensor([1, 2])],
... },
... {
... "tokens": [1, 4], "labels": [8, 9],
... # One image with four tiles
... "images": [torch.ones(4, c, h, w)],
... "encoder_mask": [torch.ones(2, 5 * 4)],
... "aspect_ratio": [torch.tensor([1, 2])],
... },
... ]
>>> model_inputs = padded_collate_tiled_images_with_cross_attention(batch=batch)
Contributor: This name doesn't match the current name. I actually prefer padded_collate_vision_text, as it's more straightforward, and we can either generalize this function or split and rename it as we get more vision-text models in the future.

Collaborator: +1

>>> print(model_inputs["tokens"])
tensor([[1, 2, 1, 3],
[1, 4, 0, 0]])
>>> print(model_inputs["labels"])
tensor([[4, 5, 6, 7],
[8, 9, -100, -100]])
>>> print(model_inputs["images"].shape) # (bsz, max_num_images, max_num_tiles, c, h, w)
torch.Size([2, 2, 4, 1, 1, 1])
>>> print(model_inputs["encoder_mask"].shape) # (bsz, max_num_images, max_num_tiles, tokens_per_tile * max_num_tiles)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should actually be [2, 4, 40] since cross attention is text vs image sequence and the image sequence is num_imagesnum_tilestokens_per_tile.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

but what about batch size?

torch.Size([2, 2, 4, 20])
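For concreteness, a minimal sketch of the flattening the reviewer describes (editor's illustration using the docstring batch above; not part of this PR), folding the per-image axis into the image-sequence axis:

>>> out = padded_collate_tiled_images_with_cross_attention(batch=batch)
>>> bsz, n_imgs, text_len, img_len = out["encoder_mask"].shape  # (2, 2, 4, 20)
>>> flat = out["encoder_mask"].transpose(1, 2).reshape(bsz, text_len, n_imgs * img_len)
>>> print(flat.shape)
torch.Size([2, 4, 40])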
>>> print(model_inputs["aspect_ratio"].shape) # (bsz, max_num_images, 2)
torch.Size([2, 2, 2])
>>> print(model_inputs["images"][0, 0, ...]) # Image with two tiles got padded to four
tensor([[[[1.]]], [[[1.]]], [[[0.]]], [[[0.]]]])
>>> print(model_inputs["images"][0, 1, ...]) # Image with three tiles got padded to four
tensor([[[[1.]]], [[[1.]]], [[[1.]]], [[[0.]]]])
>>> print(model_inputs["images"][1, 0, ...]) # Image with four tiles did not get padded
tensor([[[[1.]]], [[[1.]]], [[[1.]]], [[[1.]]]])
>>> print(model_inputs["images"][1, 1, ...]) # Extra padding image was added to second sample
tensor([[[[0.]]], [[[0.]]], [[[0.]]], [[[0.]]]])
"""
# Text tokens and labels can be handled independently by the existing SFT collater
text_only = [
{"tokens": sample["tokens"], "labels": sample["labels"]} for sample in batch
]
collated_text = padded_collate_sft(text_only, padding_idx, ignore_idx)
max_seq_len = collated_text["tokens"].shape[-1]

# TODO: Figure out how to make this more efficient or vectorized. Setting
# max_num_tiles beforehand will save one nested for loop but may incur more
# memory and compute costs in attention if max_num_tiles > batch_max_num_tiles

@felipemello1 (Contributor, Sep 5, 2024): Didn't think too much about it, but maybe:

  1. Do a first pass to find the max of each dimension.
  2. Create a tensor of all zeros. Pre-allocating should simplify all the padding.
  3. Add each input to the correct slice of the tensor, e.g. tensor[0] += sample.

Contributor Author: Pre-allocating would definitely simplify the code. I would still need to loop through each individual image, though.

Contributor Author: I will leave this as a follow-up in the interest of time.
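A minimal sketch of the pre-allocation idea for just the image tensors (editor's illustration; preallocated_collate_images is a hypothetical helper, not code from this PR):

import torch

def preallocated_collate_images(batch):
    # Pass 1: find the max of each ragged dimension.
    bsz = len(batch)
    max_num_images = max(len(s["images"]) for s in batch)
    max_num_tiles = max(img.shape[0] for s in batch for img in s["images"])
    c, h, w = batch[0]["images"][0].shape[1:]
    # Pass 2: pre-allocate zeros and copy each image into its slice;
    # anything not written stays zero, which doubles as padding.
    out = torch.zeros(bsz, max_num_images, max_num_tiles, c, h, w)
    for i, sample in enumerate(batch):
        for j, img in enumerate(sample["images"]):
            out[i, j, : img.shape[0]] = img
    return out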


# First loop: get max number of tiles in batch
max_num_tiles = max(
image.shape[0] for sample in batch for image in sample["images"]
)
# Second loop: pad images and masks to max number of tiles, max text seq len in batch
batch_images = []
batch_masks = []
batch_aspect_ratios = []
for sample in batch:
sample_images = []
sample_masks = []
for image, mask in zip(sample["images"], sample["encoder_mask"]):
# Single image in each sample has shape (n_tiles, c, h, w)
n_tiles = image.shape[0]
# Single mask in each sample corresponds to a single image and has shape (text_seq_len, image_seq_len)
# where image_seq_len = n_tiles * tokens_per_tile
text_seq_len, image_seq_len = mask.shape
tokens_per_tile = image_seq_len // n_tiles
padding_tiles = max_num_tiles - n_tiles
padding_text = max_seq_len - text_seq_len
# Image should now have shape (max_num_tiles, c, h, w)
padded_image = F.pad(image, (0, 0, 0, 0, 0, 0, 0, padding_tiles), value=0)
# Mask should now have shape (max_seq_len, max_image_seq_len), where
# max_image_seq_len = max_num_tiles * tokens_per_tile
padded_mask = F.pad(
mask, (0, padding_tiles * tokens_per_tile, 0, padding_text), value=0
)
sample_images.append(padded_image)
sample_masks.append(padded_mask)
# Stack multiple images and masks per sample in num_images dimension
batch_images.append(torch.stack(sample_images))
batch_masks.append(torch.stack(sample_masks))
batch_aspect_ratios.append(torch.stack(sample["aspect_ratio"]))
# Finally, pad images, masks, aspect ratios to max number of images in batch
# (bsz, max_num_images, max_num_tiles, c, h, w)
collated_images = pad_sequence(batch_images, batch_first=True, padding_value=0)
# (bsz, max_num_images, max_seq_len, max_image_seq_len)
collated_masks = pad_sequence(batch_masks, batch_first=True, padding_value=0)
# (bsz, max_num_images, 2)
collated_aspect_ratios = pad_sequence(
batch_aspect_ratios, batch_first=True, padding_value=1
)

return {
"tokens": collated_text["tokens"],
"labels": collated_text["labels"],
"images": collated_images,
"encoder_mask": collated_masks,
"aspect_ratio": collated_aspect_ratios,
}


def padded_collate_dpo(
batch: List[Dict[str, List[int]]],
padding_idx: int = 0,