Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Images in Messages #1504

Merged
merged 26 commits into from
Sep 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
f63c096
[Pseudo-RFC] Images in Messages
joecummings Sep 5, 2024
da977ef
Updates
joecummings Sep 6, 2024
aab101c
<Replace this line with a title. Use 1 line only, 67 chars or less>
joecummings Sep 6, 2024
b0a6702
Updates
joecummings Sep 6, 2024
12c3c30
Stub
joecummings Sep 9, 2024
6458d1f
I AM A GOD
joecummings Sep 9, 2024
3fb44c1
Add the doggo
joecummings Sep 9, 2024
dd111ba
What have I done....
joecummings Sep 9, 2024
f87eab7
Merge remote-tracking branch 'upstream/main' into image-in-messages
joecummings Sep 9, 2024
d833db1
LINT
joecummings Sep 9, 2024
c2e00ca
Fix Llava Instruct
joecummings Sep 9, 2024
2bdb254
Last Llava test fixes
joecummings Sep 9, 2024
2f8657d
Fix the Cauldron
joecummings Sep 9, 2024
751186a
Move load_image to a utils loc
joecummings Sep 9, 2024
7c63d16
Remove unnecessary changes
joecummings Sep 9, 2024
08c8112
Cleanup
joecummings Sep 9, 2024
c9416d6
Actually use PIL images in test for formatting
joecummings Sep 9, 2024
2f17dfe
Stop
joecummings Sep 9, 2024
eb97847
Convert images_dir to Path on the backend
joecummings Sep 10, 2024
26ef618
Update docstring for format_content_with_images
joecummings Sep 10, 2024
7cf3f33
Update API ref with new functions
joecummings Sep 10, 2024
932b3d2
Update DummyTokenizer to account for images
joecummings Sep 10, 2024
8007afd
Better docs rendering for data utils
joecummings Sep 10, 2024
5a74b30
More formatting + updating Message test
joecummings Sep 10, 2024
c90b766
Whoops
joecummings Sep 10, 2024
6fe061b
Update torchtune/data/_utils.py
joecummings Sep 10, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/source/api_ref_data.rst
Original file line number Diff line number Diff line change
Expand Up @@ -88,3 +88,5 @@ Miscellaneous helper functions used in modifying data.

validate_messages
truncate
load_image
format_content_with_images
Binary file added tests/assets/dog_on_skateboard.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
6 changes: 5 additions & 1 deletion tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,10 +145,14 @@ def tokenize_messages(
return tokenized_messages, mask

def __call__(self, sample: Mapping[str, Any]) -> Mapping[str, Any]:
messages = sample.pop("messages")
messages: List[Message] = sample.pop("messages")
images = []
for message in messages:
images += message.get_media()
tokens, mask = self.tokenize_messages(messages)
sample["tokens"] = tokens
sample["mask"] = mask
sample["images"] = images
return sample

@property
Expand Down
124 changes: 109 additions & 15 deletions tests/torchtune/data/test_data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,20 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import os

import pytest
from PIL import Image

from tests.common import ASSETS
from torchtune.data import (
format_content_with_images,
Message,
PromptTemplate,
split_text_by_image_tag,
truncate,
validate_messages,
)
from torchtune.data._utils import _get_prompt_template
from torchtune.data._utils import _get_prompt_template, load_image
from torchtune.models.llama2 import Llama2ChatTemplate


Expand Down Expand Up @@ -98,47 +103,136 @@ def test_validate_messages():
validate_messages(messages)


def test_split_text_by_image_tag():
def test_format_content_with_images():
test_image_1 = Image.new(mode="RGB", size=(4, 4))
test_image_2 = Image.new(mode="RGB", size=(4, 4))
test_image_3 = Image.new(mode="RGB", size=(4, 4))

# Test single image tag in the middle
text = "hello <image>world"
assert split_text_by_image_tag(text, "<image>") == [
assert format_content_with_images(
text,
image_tag="<image>",
images=[test_image_1],
) == [
{"type": "text", "content": "hello "},
{"type": "image"},
{"type": "image", "content": test_image_1},
{"type": "text", "content": "world"},
]

# Test multiple image tags and image tag in beginning
text = "[image]hello [image]world"
assert split_text_by_image_tag(text, "[image]") == [
{"type": "image"},
assert format_content_with_images(
text,
image_tag="[image]",
images=[test_image_1, test_image_2],
) == [
{"type": "image", "content": test_image_1},
{"type": "text", "content": "hello "},
{"type": "image"},
{"type": "image", "content": test_image_2},
{"type": "text", "content": "world"},
]

# Test an image tag that is not present in the text
text = "hello world"
assert split_text_by_image_tag(text, "asdfghjkl;") == [
assert format_content_with_images(text, image_tag="asdfghjkl;", images=[]) == [
{"type": "text", "content": "hello world"}
]

# Test consecutive image tags
text = "<image><image>hello <image>world"
assert split_text_by_image_tag(text, "<image>") == [
{"type": "image"},
{"type": "image"},
assert format_content_with_images(
text,
image_tag="<image>",
images=[test_image_1, test_image_2, test_image_3],
) == [
{"type": "image", "content": test_image_1},
{"type": "image", "content": test_image_2},
{"type": "text", "content": "hello "},
{"type": "image"},
{"type": "image", "content": test_image_3},
{"type": "text", "content": "world"},
]

# Test image tag at the end
text = "hello <image>"
assert split_text_by_image_tag(text, "<image>") == [
assert format_content_with_images(
text,
image_tag="<image>",
images=[test_image_1],
) == [
{"type": "text", "content": "hello "},
{"type": "image"},
{"type": "image", "content": test_image_1},
]

# Test errors when the number of images does not match the number of image tags
text = "hello <image>world"
with pytest.raises(
ValueError,
match="does not match number of image tags",
):
format_content_with_images(
text, image_tag="<image>", images=[test_image_1, test_image_2]
)


def test_load_image(monkeypatch, tmp_path):
tmp_image = str(ASSETS / "dog_on_skateboard.jpg")

# Test loading from local file
image = load_image(tmp_image)
assert isinstance(image, Image.Image)
assert image.size == (580, 403)

# Test loading from remote file
# Mock the urlopen function to return a BytesIO object
def mock_urlopen(url):
return open(tmp_image, "rb")

monkeypatch.setattr("urllib.request.urlopen", mock_urlopen)
image = load_image("http://example.com/test_image.jpg")
assert isinstance(image, Image.Image)
assert image.size == (580, 403)

# Test that a ValueError is raised when the image path is invalid
with pytest.raises(ValueError, match="Failed to open image as PIL.Image"):
load_image("invalid_path")

# Test a temporary file with invalid image data
image_path = tmp_path / "test_image.jpg"
with open(image_path, "w") as f:
f.write("Invalid image data")

# Test that a ValueError is raised when the image data is invalid
with pytest.raises(ValueError, match="Failed to open image as PIL.Image"):
load_image(str(image_path))

# Test that a ValueError is raised when there is an HTTP error
# Mock the urlopen function to raise an exception
def mock_urlopen(url):
raise Exception("Failed to load image")

monkeypatch.setattr("urllib.request.urlopen", mock_urlopen)
with pytest.raises(ValueError, match="Failed to load image"):
load_image("http://example.com/test_image.jpg")

# Test that a ValueError is raised when there is an IO error
# Create a temporary file that cannot be read
image_path = tmp_path / "test_image.jpg"
with open(image_path, "w") as f:
f.write("Test data")
os.chmod(image_path, 0o000) # Remove read permissions
with pytest.raises(ValueError, match="Failed to open image as PIL.Image"):
load_image(str(image_path))
os.chmod(image_path, 0o644) # Restore read permissions

# Test that a ValueError is raised with invalid image data is read
# Create a temporary file with invalid image data
image_path = tmp_path / "test_image.jpg"
with open(image_path, "wb") as f:
f.write(b"Invalid image data")
with pytest.raises(ValueError, match="Failed to open image as PIL.Image"):
load_image(str(image_path))


def test_get_prompt_template():
template = _get_prompt_template("torchtune.models.llama2.Llama2ChatTemplate")
Expand Down
18 changes: 14 additions & 4 deletions tests/torchtune/data/test_messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
# LICENSE file in the root directory of this source tree.

import pytest

from PIL import Image
from tests.test_utils import (
assert_dialogue_equal,
CHAT_SAMPLE,
Expand All @@ -26,17 +28,21 @@ def text_message(self):
return Message(role="user", content="hello world")

@pytest.fixture
def image_message(self):
def test_image(self):
return Image.new(mode="RGB", size=(4, 4))

@pytest.fixture
def image_message(self, test_image):
return Message(
role="user",
content=[
{"type": "text", "content": "hello"},
{"type": "image"},
{"type": "image", "content": test_image},
{"type": "text", "content": " world"},
],
)

def test_message_validation(self, text_message):
def test_message_validation(self, text_message, test_image):
message = text_message
assert message.role == "user"
assert message.content == [{"type": "text", "content": "hello world"}]
Expand All @@ -53,7 +59,7 @@ def test_message_validation(self, text_message):
):
message = Message(
role="user",
content=[{"type": "image"}],
content=[{"type": "image", "content": test_image}],
ipython=True,
)

Expand All @@ -69,6 +75,10 @@ def test_contains_media(self, text_message, image_message):
assert not text_message.contains_media
assert image_message.contains_media

def test_get_media(self, text_message, image_message, test_image):
assert text_message.get_media() == []
assert image_message.get_media() == [test_image]

def test_text_content(self, text_message, image_message):
assert text_message.text_content == "hello world"
assert image_message.text_content == "hello world"
Expand Down
27 changes: 23 additions & 4 deletions tests/torchtune/datasets/multimodal/test_llava_instruct_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
from collections import Counter
from unittest.mock import patch

import PIL

import pytest
from datasets import Dataset

Expand All @@ -21,11 +23,22 @@ class TestLLaVAInstructDataset:
def tokenizer(self):
return DummyTokenizer()

@pytest.fixture
def test_image_pil(self):
return PIL.Image.new(mode="RGB", size=(4, 4))

@patch("torchtune.datasets._sft.load_dataset")
def test_label_no_masking(self, load_dataset, tokenizer):
@patch("torchtune.datasets.multimodal._llava_instruct.load_image")
def test_label_no_masking(
self, load_image, load_dataset, tokenizer, test_image_pil
):
"""
Test whether the input and the labels are correctly created when the input is not masked.

WARNING: careful with these mocks, they are applied in bottom up order
"""
# mock the call to load_image
load_image.return_value = test_image_pil

# mock the call to HF datasets
load_dataset.return_value = Dataset.from_list(
Expand Down Expand Up @@ -55,6 +68,7 @@ def test_label_no_masking(self, load_dataset, tokenizer):
model_transform=tokenizer,
train_on_input=True,
)

input, labels, images = ds[0]["tokens"], ds[0]["labels"], ds[0]["images"]

expected_count = {
Expand All @@ -76,13 +90,18 @@ def test_label_no_masking(self, load_dataset, tokenizer):

assert Counter(input) == expected_count
assert Counter(labels) == expected_count
assert images == "test_image.jpg"
assert images == [test_image_pil]

@patch("torchtune.datasets._sft.load_dataset")
def test_label_masking(self, load_dataset, tokenizer):
@patch("torchtune.datasets.multimodal._llava_instruct.load_image")
def test_label_masking(self, load_image, load_dataset, tokenizer, test_image_pil):
"""
Test whether the input and the labels are correctly created when the input is masked.

WARNING: careful with these mocks, they are applied in bottom up order
"""
# mock the call to load_image
load_image.return_value = test_image_pil

# mock the call to HF datasets
load_dataset.return_value = Dataset.from_list(
Expand Down Expand Up @@ -133,4 +152,4 @@ def test_label_masking(self, load_dataset, tokenizer):

assert Counter(input) == expected_count
assert labels.count(CROSS_ENTROPY_IGNORE_IDX) == 11
assert images == "test_image.jpg"
assert images == [test_image_pil]
Loading
Loading