[FIX] MM Eval Mask Sizes #1920

Merged (3 commits) on Oct 30, 2024
5 changes: 4 additions & 1 deletion recipes/dev/generate_v2.py
@@ -152,7 +152,10 @@ def generate(self, cfg: DictConfig):
         batch = {}
         if is_multimodal_input:
             batch = padded_collate_tiled_images_and_mask(
-                [model_inputs], pad_direction="left", pad_max_images=1
+                [model_inputs],
+                pad_direction="left",
+                pad_max_images=1,
+                pad_max_tiles=self.model_transform.max_num_tiles,
             )
             batch["encoder_mask"] = batch["encoder_mask"][:, :seq_len]
         prompt = batch.pop("tokens").to(self._device)
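A note on why pad_max_tiles is threaded through here: the mask produced by the transform only covers the tiles that actually exist in the prompt's image, while the encoder output (and the cross-attention cache) is sized for the maximum tile count. A minimal sketch of the target width the collator pads to, using invented sizes rather than values from this recipe:

import torch
import torch.nn.functional as F

# Hypothetical sizes for illustration only (not taken from this recipe):
text_seq_len = 6        # text tokens in the prompt
tokens_per_tile = 5     # image tokens the encoder emits per tile
max_num_tiles = 4       # what the transform can produce per image (pad_max_tiles)
pad_max_images = 1      # generate_v2.py passes pad_max_images=1
num_tiles_present = 3   # tiles actually present in the single input image

# The mask coming out of the transform only covers tiles that exist.
mask = torch.ones(text_seq_len, num_tiles_present * tokens_per_tile)

# The collator pads it out to the full width the cross-attention cache expects.
target_width = pad_max_images * max_num_tiles * tokens_per_tile   # 20
padded = F.pad(mask, (0, target_width - mask.shape[-1]), value=0)
print(padded.shape)  # torch.Size([6, 20])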
1 change: 1 addition & 0 deletions recipes/eleuther_eval.py
@@ -187,6 +187,7 @@ def tok_batch_multimodal_encode(
             all_encoded_messages,
             pad_direction="left",
             pad_max_images=self._max_images_per_sample,
+            pad_max_tiles=self._transform.max_num_tiles,
         )
         utils.batch_to_device(tok_batch, self.device)
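A rough illustration, with invented sizes, of the mismatch this one-line change avoids: cross-attention scores have shape (..., text_len, encoder_len), and the boolean mask must share encoder_len. If the encoder sees images padded to the maximum tile count but the mask was built only for the tiles present, the two widths disagree during eval:

import torch

# Invented sizes for illustration:
text_len, tokens_per_tile, max_num_tiles, num_images = 7, 5, 4, 2
encoder_len = num_images * max_num_tiles * tokens_per_tile   # 40 cached image positions

scores = torch.randn(1, text_len, encoder_len)
mask = torch.zeros(1, text_len, encoder_len, dtype=torch.bool)
mask[:, :, : 2 * tokens_per_tile] = True   # only the first image's 2 real tiles are attended
masked = scores.masked_fill(~mask, float("-inf"))
print(masked.shape)  # torch.Size([1, 7, 40]); a narrower mask would fail to broadcast here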

40 changes: 28 additions & 12 deletions tests/torchtune/data/test_collate.py
@@ -56,33 +56,41 @@ def test_batch_pad_sequence(self):


 class TestPaddedCollateTiledImagesAndMask:
+    img_shape = 1, 1, 1
+    tokens_per_tile = 5
+
     @pytest.fixture
     def batch(self):
+        c, h, w = self.img_shape
+        s = self.tokens_per_tile
         return [
             {
                 "tokens": [1, 2, 1, 3],
                 "labels": [4, 5, 6, 7],
                 "encoder_input": {
-                    "images": [torch.ones(2, 1, 1, 1), torch.ones(3, 1, 1, 1)],
+                    "images": [torch.ones(2, c, h, w), torch.ones(3, c, h, w)],
                     "aspect_ratio": [torch.tensor([1, 2]), torch.tensor([1, 3])],
                 },
-                "encoder_mask": [torch.ones(4, 5 * 2), torch.ones(4, 5 * 3)],
+                "encoder_mask": [torch.ones(4, s * 2), torch.ones(4, s * 3)],
             },
             {
                 "tokens": [1, 4],
                 "labels": [8, 9],
                 "encoder_input": {
-                    "images": [torch.ones(4, 1, 1, 1)],
+                    "images": [torch.ones(4, c, h, w)],
                     "aspect_ratio": [torch.tensor([2, 2])],
                 },
-                "encoder_mask": [torch.ones(2, 5 * 4)],
+                "encoder_mask": [torch.ones(2, s * 4)],
             },
         ]

     def test_right_pad_sequence(self, batch):
         actual = padded_collate_tiled_images_and_mask(
             batch=batch, padding_idx=0, ignore_idx=-100, pad_direction="right"
         )
+        imgs, tiles = actual["encoder_input"]["images"].shape[1:3]
+        seq_len = actual["encoder_mask"].shape[-1]
+        assert imgs * tiles * self.tokens_per_tile == seq_len

         mask_1 = torch.concat([torch.ones(4, 5 * 2), torch.zeros(4, 10)], dim=1)
         mask_2 = torch.concat([torch.ones(4, 5 * 3), torch.zeros(4, 5)], dim=1)
@@ -126,28 +134,36 @@ def test_left_pad_sequence(self, batch):
             ignore_idx=-100,
             pad_direction="left",
             pad_max_images=4,
+            pad_max_tiles=5,
         )
+        imgs, tiles = actual["encoder_input"]["images"].shape[1:3]
+        seq_len = actual["encoder_mask"].shape[-1]
+        assert 5 * 4 * self.tokens_per_tile == seq_len
+
-        mask_1 = torch.concat([torch.ones(4, 5 * 2), torch.zeros(4, 10)], dim=1)
-        mask_2 = torch.concat([torch.ones(4, 5 * 3), torch.zeros(4, 5)], dim=1)
+        # pad 3 extra tiles
+        mask_1 = torch.concat([torch.ones(4, 5 * 2), torch.zeros(4, 5 * 3)], dim=1)
+        # pad 2 extra tiles
+        mask_2 = torch.concat([torch.ones(4, 5 * 3), torch.zeros(4, 5 * 2)], dim=1)
         # Left pad text tokens
         mask_3 = torch.concat([torch.zeros(2, 20), torch.ones(2, 5 * 4)], dim=0)
+        mask_3 = F.pad(mask_3, (0, 5), value=0)  # pad 5th tile
         sample_1 = torch.stack([mask_1, mask_2])
-        sample_2 = torch.stack([mask_3, torch.zeros(4, 20)])
+        sample_2 = torch.stack([mask_3, torch.zeros(4, 25)])
         expected_mask = torch.stack([sample_1, sample_2]).view(2, 4, -1)
-        expected_mask = F.pad(expected_mask, (0, 40), value=0)
+        expected_mask = F.pad(expected_mask, (0, 50), value=0)

         expected = {
             "tokens": torch.tensor([[1, 2, 1, 3], [0, 0, 1, 4]]),
             "encoder_input": {
                 "images": torch.tensor(
                     [
                         [
-                            [[[[1.0]]], [[[1.0]]], [[[0.0]]], [[[0.0]]]],
-                            [[[[1.0]]], [[[1.0]]], [[[1.0]]], [[[0.0]]]],
+                            [[[[1.0]]], [[[1.0]]], [[[0.0]]], [[[0.0]]], [[[0.0]]]],
+                            [[[[1.0]]], [[[1.0]]], [[[1.0]]], [[[0.0]]], [[[0.0]]]],
                         ],
                         [
-                            [[[[1.0]]], [[[1.0]]], [[[1.0]]], [[[1.0]]]],
-                            [[[[0.0]]], [[[0.0]]], [[[0.0]]], [[[0.0]]]],
+                            [[[[1.0]]], [[[1.0]]], [[[1.0]]], [[[1.0]]], [[[0.0]]]],
+                            [[[[0.0]]], [[[0.0]]], [[[0.0]]], [[[0.0]]], [[[0.0]]]],
                         ],
                     ]
                 ),
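The expected numbers in test_left_pad_sequence follow directly from the collate arguments; a small worked computation (values copied from the test) for the final mask width and for the F.pad(expected_mask, (0, 50)) above:

tokens_per_tile = 5
pad_max_tiles = 5
pad_max_images = 4
images_in_batch = 2                                 # most images any one sample has

per_image_width = pad_max_tiles * tokens_per_tile   # 25 mask columns per image
stacked_width = images_in_batch * per_image_width   # 50, width before padding to max images
final_width = pad_max_images * per_image_width      # 100, matches the seq_len assert above
print(per_image_width, stacked_width, final_width)  # 25 50 100
print(final_width - stacked_width)                  # 50, the amount F.pad adds on the right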
4 changes: 1 addition & 3 deletions tests/torchtune/modules/transforms/test_transforms.py
@@ -10,7 +10,6 @@


 IMAGE_TOKEN_ID = 1
-MAX_NUM_TILES = 4


 class TestVisionCrossAttentionMask:
@@ -54,7 +53,6 @@ def cross_attn_mask_transform(self, tile_size, patch_size):
             tile_size=tile_size,
             patch_size=patch_size,
             image_token_id=IMAGE_TOKEN_ID,
-            max_num_tiles=MAX_NUM_TILES,
         )

     def test_get_image_attention_intervals(self, cross_attn_mask_transform, tokens):
@@ -89,7 +87,7 @@ def test_inference_call(
         sample.update(dummy_kwargs)
         actual = cross_attn_mask_transform(sample, inference=True)
         expected = [
-            torch.zeros(len(tokens), image_num_tokens * 2, dtype=torch.bool)
+            torch.zeros(len(tokens), image_num_tokens, dtype=torch.bool)
             for _ in range(len(images))
         ]
         expected[0][2:6, :image_num_tokens] = True
3 changes: 2 additions & 1 deletion torchtune/data/_collate.py
@@ -426,7 +426,8 @@ def padded_collate_tiled_images_and_mask(
     if pad_max_images is not None:
         _, _, img_seq = concat_masks.shape
         concat_masks = F.pad(
-            concat_masks, (0, pad_max_images * image_seq_len - img_seq)
+            concat_masks,
+            (0, pad_max_images * max_num_tiles * tokens_per_tile - img_seq),
         )

     batch_dict = {

Review comments on this change:

@RdoubleA (Contributor), Oct 29, 2024:
Where does the pad to max images happen? Is this just padding the masks to the max number of images? And would pad_direction affect image padding at all?

Author (Contributor), reply:
You don't actually need to add the padded images, or you'd waste compute on them. If you have a KV cache sized for 7 images, you just want to mask out the additional image positions. It's similar to how you mask extra token positions during inference instead of adding padding tokens.
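A single-sample sketch of the behavior discussed in this thread, reusing the second fixture sample from tests/torchtune/data/test_collate.py; the commented shapes follow the expectations asserted in that test rather than an independent spec:

import torch
from torchtune.data import padded_collate_tiled_images_and_mask

# One sample with a single 4-tile image and two text tokens (sizes mirror the test fixture).
sample = {
    "tokens": [1, 4],
    "labels": [8, 9],
    "encoder_input": {
        "images": [torch.ones(4, 1, 1, 1)],
        "aspect_ratio": [torch.tensor([2, 2])],
    },
    "encoder_mask": [torch.ones(2, 5 * 4)],
}
batch = padded_collate_tiled_images_and_mask(
    [sample], pad_direction="left", pad_max_images=4, pad_max_tiles=5
)
# Images are only padded up to this batch's image/tile counts...
print(batch["encoder_input"]["images"].shape)  # expected (1, 1, 5, 1, 1, 1)
# ...while the mask is padded to pad_max_images * pad_max_tiles * tokens_per_tile columns.
print(batch["encoder_mask"].shape)             # expected (1, 2, 4 * 5 * 5) == (1, 2, 100)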
2 changes: 1 addition & 1 deletion torchtune/models/llama3_2_vision/_transform.py
@@ -93,11 +93,11 @@ def __init__(
             tile_size=tile_size,
             patch_size=patch_size,
             image_token_id=self.tokenizer.image_id,
-            max_num_tiles=max_num_tiles,
         )

         self.stop_tokens = self.tokenizer.stop_tokens
         self.max_seq_len = max_seq_len
+        self.max_num_tiles = max_num_tiles
         self.image_seq_len = max_num_tiles * (self.xattn_mask.patches_per_tile + 1)
         self.prompt_template = prompt_template
         self.pad_id = self.tokenizer.pad_id
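Exposing self.max_num_tiles here is what the recipe changes above rely on (pad_max_tiles=self.model_transform.max_num_tiles and pad_max_tiles=self._transform.max_num_tiles). Rough arithmetic behind image_seq_len, using Llama 3.2 Vision-style numbers that are assumptions rather than values read from this diff:

tile_size = 448
patch_size = 14
max_num_tiles = 4

patches_per_tile = (tile_size // patch_size) ** 2        # 32 * 32 = 1024
image_seq_len = max_num_tiles * (patches_per_tile + 1)   # the +1 is the extra per-tile (CLS) token
print(patches_per_tile, image_seq_len)                   # 1024 4100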
13 changes: 2 additions & 11 deletions torchtune/modules/transforms/_transforms.py
@@ -4,7 +4,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

-from typing import Any, List, Mapping, Optional, Protocol
+from typing import Any, List, Mapping, Protocol

 import torch

@@ -57,21 +57,17 @@ class VisionCrossAttentionMask(Transform):
             E.g. for patch_size = 40, a tile of shape (400, 400) will have 10x10 grid of patches
             with shape (40, 40) each.
         image_token_id (int): Token ID of the image special token.
-        max_num_tiles (Optional[int]): Maximum number of tiles in an image, used to
-            pad mask during inference. Defaults to None
     """

     def __init__(
         self,
         tile_size: int,
         patch_size: int,
         image_token_id: int,
-        max_num_tiles: Optional[int] = None,
     ):
         patch_grid_size = tile_size // patch_size
         self.patches_per_tile = patch_grid_size**2
         self.image_token_id = image_token_id
-        self.max_num_tiles = max_num_tiles

     def _get_image_attention_intervals(self, tokens: List[int]) -> List[List[int]]:
         """
@@ -163,9 +159,6 @@ def __call__(
         # which can vary based on number of tiles since they are not yet tile padded.
         # The masks are padded and concatenated together in the batch collator
         text_seq_len = len(tokens)
-        max_image_size = None
-        if inference and self.max_num_tiles is not None:
-            max_image_size = self.max_num_tiles * (self.patches_per_tile + 1)
         masks = []
         for image_num, interval in enumerate(intervals):
             # Identify what part of text sequence should be attended
@@ -178,9 +171,7 @@ def __call__(
             # to a single image, so text tokens attend to all the image's tokens.
             # The mask is text_seq_len x mask_image_size if defined, otherwise
             # it uses current text/image sequence lengths.
-            mask = torch.zeros(
-                text_seq_len, max_image_size or image_seq_len, dtype=torch.bool
-            )
+            mask = torch.zeros(text_seq_len, image_seq_len, dtype=torch.bool)
             mask[start:end, :image_seq_len] = True
             masks.append(mask)
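With max_num_tiles removed from the transform, __call__ now builds each mask at the image's true size and defers tile padding to the batch collator. A minimal sketch, with invented sizes, of what one per-image mask looks like after this change:

import torch

text_seq_len = 8
patches_per_tile = 4
n_tiles = 3                                        # tiles actually present in this image
image_seq_len = n_tiles * (patches_per_tile + 1)   # 15 image tokens for this image

start, end = 2, 8                                  # text interval that attends to this image
mask = torch.zeros(text_seq_len, image_seq_len, dtype=torch.bool)
mask[start:end, :image_seq_len] = True
print(mask.shape)  # torch.Size([8, 15]); no longer padded to max_num_tiles at this stage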
