
Images in Messages #1504

Merged

joecummings merged 26 commits into pytorch:main from the image-in-messages branch on Sep 10, 2024

Conversation

@joecummings (Contributor) commented Sep 5, 2024

Put image content in Message (and other suspicious changes)

Background

No snappy intro, I'm just tired.

This all started from considering the inference UX for multimodal models. The ideal state would be something like this in the config:

prompt:
  - system: You are a helpful assistant.
  - user:
      - image: https://en.wikipedia.org/wiki/Wikipedia:Images#/media/File:Spinifex_Pigeon_0A2A1585.jpg
      - text: What's in this image?

Then the user can call the generate recipe:

python recipes/generate.py --config recipes/configs/multimodal_generation.yaml
This is an image of a bird sitting on a tree. 

But, what is this getting at?

The idea is that a "Message" should contain all the information needed for an interaction with a model, INCLUDING images. Before, images were split off from Messages and a placeholder was put in the Message object so the model knew where to inject the content:

message_content = [
	{"type": "image"},
	{"type": "text", "content": "What is in this image?"},
]

images = ["https://en.wikipedia.org/wiki/Wikipedia:Images#/media/File:Spinifex_Pigeon_0A2A1585.jpg"]

After these changes, the model now has ALL the information it needs within the message content itself:

message_content = [
	{"type": "image", "content":  <PIL.Image.Image>},
	{"type": "text", "content": "What is in this image?"},
]

The core of the change is relatively simple; however, we had already baked the split-images approach into all our multimodal code, so that all had to be reworked as well. See below for a full changelog.

Changelog

  • Added a "content" field for PIL image types in message content
  • Added a get_media function to easily get all the image content from a Message (see the sketch after this list)
  • Updated the Flamingo transform to use this new functionality
  • Added a function to load a PIL image from a local drive or a remote source
  • Added a new function to format content with images properly
  • Updated the Cauldron and Llava Instruct datasets with the new changes
  • Tests!!! Including uploading an asset to test that an image can actually be loaded properly
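
Here's a rough usage sketch tying the new pieces together. The names (load_image, format_content_with_images, get_media) come from the changelog above, but the exact import paths and signatures shown here are assumptions, not the final API:

from torchtune.data import Message, format_content_with_images, load_image

# Load the image from either a local path or a remote URL.
pil_image = load_image("https://example.com/bird.jpg")  # placeholder URL

# Split the raw text on the image tag and interleave the loaded PIL image.
content = format_content_with_images(
    "<image>What is in this image?",
    image_tag="<image>",
    images=[pil_image],
)

# The Message now carries everything a model transform needs.
msg = Message(role="user", content=content)

# Pull all image content back out of the message
# (get_media may be a method or a free function; treated as a method here).
images = msg.get_media()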

Testing

(joe-torchtune) [[email protected] ~/projects/joe-torchtune (3fb44c161)]$ python -m pytest tests/torchtune/
...
 =============================================== 467 passed, 5 skipped, 8 warnings in 119.40s (0:01:59) ===============================================

FAQs

Why did you make it possible to load from remote or local sources?
Loading from a URL is a feature usually demonstrated for inference, as in Phi3 Vision and Qwen2 vision. We don't want to force users to download the image to their device before they can start, so this handles both cases.
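
As a minimal sketch, handling both cases can look like the following (assumptions only; the actual helper in this PR differs in naming and does more error handling):

from io import BytesIO
from urllib.request import urlopen

from PIL import Image


def load_image_sketch(image_loc: str) -> Image.Image:
    # Remote source: download the bytes, then open them with PIL.
    if image_loc.startswith(("http://", "https://")):
        data = urlopen(image_loc).read()
        return Image.open(BytesIO(data))
    # Local source: PIL can open the path directly.
    return Image.open(image_loc)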

Why not convert to a PIL image within Message?
I wanted to do this originally; however, Message is a class used for both text-only and multimodal models. Right now, it's possible to train and run inference without installing the extra dependencies needed for multimodal, like PIL and torchvision. This keeps the library lightweight. In keeping with that, making the specific multimodal model transforms load and use PIL keeps this usage separate. I'd be open to changing this if PIL and multimodal become so important that they're used everywhere, but right now they're not. I did it, sue me.

pytorch-bot bot commented Sep 5, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1504

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 6fe061b with merge base 66590b4:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot added the CLA Signed label on Sep 5, 2024
@RdoubleA (Contributor) left a comment

This would mean that model transform just takes in a List[Message] directly instead of a sample dictionary, or do we still plan to pass in sample["messages"] to model transforms?

@@ -89,6 +89,8 @@ def __call__(self, sample: Mapping[str, Any]) -> Mapping[str, Any]:
role="system", content=self.new_system_prompt, masked=True, eot=True
)
)

# Add in image stuffs / load from file
Contributor

I am leaning towards handling image loading in the model transform and keeping a URL here, so we don't need to pass in a heavy PIL image everywhere Messages is used. We only load in the image when it's absolutely needed. For tokenize_messages, messages would be more lightweight since it won't contain the actual image.

Contributor

Re passing around PIL images, do you mean that there are implications for memory to doing something like this? Cause if we are just passing stuff by reference it shouldn't be an issue, right?

If there are a bunch of operations that we want to apply to messages that'd be more convenient to have the file path for, that'd be one thing. But it seems to me like most (nontrivial) functions we're gonna call on Messages containing images would be those that act on the image and not the path.

In that case I'd be inclined to have things as a PIL because it's the raw datatype of an image. It's analogous to the raw string we have in the content field of a Message with text type. I think {"text" -> str, "image" -> PIL} is a lot more natural for a user than {"text" -> str, "image" -> Path}. But lmk if I'm missing the point here

Contributor Author

See comment in FAQs. Generally, I agree, but I'm making the tradeoff for "weight" of dependencies.

Contributor

Right now, it's possible to train and run inference without downloading extra things needed for multimodal like PIL and torchvision

I think Joe's point here is quite valid. I agree that having it as a PIL image by default is more intuitive, but as of now multimodal is not prevalent enough as a use case.

@joecummings changed the title from "[Pseudo-RFC] Images in Messages" to "Images in Messages" on Sep 9, 2024
@joecummings marked this pull request as ready for review on September 9, 2024 17:55
Contributor

totally tubular

Contributor Author

"Requests changes"

😭

@@ -55,7 +55,8 @@ def test_label_no_masking(self, load_dataset, tokenizer):
model_transform=tokenizer,
train_on_input=True,
)
input, labels, images = ds[0]["tokens"], ds[0]["labels"], ds[0]["images"]

input, labels = ds[0]["tokens"], ds[0]["labels"]
Contributor

we should have DummyTokenizer load in images from the message (to replicate the model transform) and then check that the image tensor is returned here as expected. that way it would be a good e2e test of passing the images through the messages into the model transform

Contributor Author

Not sure I follow. Our tokenizers (model transforms) take in images through the messages, but they do not return them afterwards.

Looking at Flamingo transform for this.

Contributor

we're using DummyTokenizer as the substitute for model transform here. Really it should be called DummyModelTransform. My point is we should test that images are passed through from the dataset to messages and are processed correctly. Right now with these changes we are not checking images at all, only the image special tokens.

if isinstance(content, str)
else content
)
self.content = self._convert_to_list_of_dict(content)
Contributor

very nice

{"type": "image"},
{
"type": "image",
"content": sample[self._column_map["images"]],
Contributor

this will be a PIL image. can image content still double as PIL image or path string?

Contributor Author

I mean it can - but it should be consistent

Contributor Author

Damn, if this is coming in as a PIL image, then it might make sense to go ahead with the changes to load image into Message as a PIL image

Contributor

ha... thinking about it a bit more, I suppose there's no actual memory hit. if you're using multimodal datasets you will certainly need to load the image at some point. then again, changing my mind after five minutes hurts my pride just a little

Contributor Author

I don't really want a message to be Union[Path, str, PIL.Image]

That seems like a recipe for confusion.

Contributor Author

ALRIGHT F IT, here's the plan:

I was a big ol dumb dumb. This stuff is gunna go. Messages will now hold PIL images. This will be loosely typed in Messages so that we don't need to actually import PIL, and Messages can stay free of that dependency hell. Then, the onus for providing a proper PIL image to the Message class is on the dataset builder and the generation recipe. For generation, we could just have a separate recipe for multimodal or require PIL for everything. For dataset builders, everything is in the multimodal folder, which protects us from imports (thanks for the forward thinking @RdoubleA)

This makes the most logical sense IMO.

data entry points (datasets, generation) load in the PIL image -> feed to message in proper format -> message is used by model transform however it wants.

Okay? Okay.

Contributor

Idk what just happened here but I'm on board

Contributor

ALRIGHT F IT

you right, I'm on board

torchtune/datasets/multimodal/_llava_instruct.py (outdated, resolved)
torchtune/models/clip/_transform.py (outdated, resolved)
@RdoubleA (Contributor) commented Sep 9, 2024

No snappy intro, I'm just tired.

I feel that

@joecummings requested a review from RdoubleA on September 9, 2024 22:42


# TODO: point to Flamingo model transform as an example
def llava_instruct_dataset(
model_transform: Transform,
*,
source: str = "liuhaotian/LLaVA-Instruct-150K",
images_dir: str = "coco/",
Contributor Author

I would love some input here. Right now this is a string because I'm anticipating usage from a config. However, the proper typing would be something like pathlib.Path because it's much more tolerant of errant forward slashes and plays better when combining sub-paths.

But then, I'll have to do some conversion on the backside to do that.

Thoughts? @RdoubleA @ebsmothers

Contributor

If we convert to Path on the backside but still maintain a string parameter, that should handle any trailing/incorrect slashes right? If so I would do that
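
A quick sketch of that approach, with a hypothetical builder name; nothing here is the final torchtune signature:

from pathlib import Path


def llava_instruct_dataset_sketch(*, images_dir: str = "coco/"):
    # Accept a config-friendly string, normalize to Path internally.
    images_dir = Path(images_dir)
    # Path joining tolerates a missing or extra slash in the config value.
    image_path = images_dir / "some_image.jpg"  # placeholder file name
    return image_path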

PIL.Image.Image: The loaded image.
"""
# Hackily import PIL to avoid burdensome import in the main module
# TODO: Fix this
Contributor

isn't this import already pretty prevalent throughout the library? or are you avoiding some other issue?

Contributor Author

It's only prevalent if doing MM. Right now, you can completely run all our text stuff without the need to import torchvision, PIL, etc...

I'd love to keep it that way.
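
For reference, the deferred-import pattern being discussed looks roughly like this (illustrative only; the actual helper lives in torchtune/data/_utils.py and differs in detail):

def load_image(image_loc):
    # Import PIL inside the function so text-only users never need Pillow installed.
    from PIL import Image

    return Image.open(image_loc)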

Contributor

What's the fix here? Honestly I think this is an OK solution

Contributor Author

Would be putting multimodal stuff in a subfolder likely.

@@ -53,7 +90,7 @@ def split_text_by_image_tag(content: str, image_tag: str) -> List[Dict[str, str]
"role": "system" | "user" | "assistant",
"content":
[
{"type": "image"},
{"type": "image", "content": "path/to/image1.png"},
Contributor

this needs to be PIL image

@joecummings (Contributor Author) commented Sep 10, 2024

u need to be a pil image 😡

>>> content = format_content_with_media(
... "<|image|>hello <|image|>world",
... image_tag="<|image|>",
... images=[<PIL.Image.Image>, <PIL.Image.Image>"]
Contributor

Suggested change
... images=[<PIL.Image.Image>, <PIL.Image.Image>"]
... images=[<PIL.Image.Image>, <PIL.Image.Image>]

return image


def format_content_with_images(
Contributor

nit: maybe the name insert_images_in_content is more clear on what the function is doing

Contributor Author

True, but it also does the splitting?

Contributor

splitting so we can insert the images in the correct location :) but, your call


def format_content_with_images(
content: str, *, image_tag: str, images: List["PIL.Image.Image"]
) -> List[Dict[str, str]]:
"""
Given a raw text string, split by the specified ``image_tag``
and form into list of dictionaries to be used in the ``Message`` content
Contributor

nit: while you're here, could you describe what happens if image tag is not found in the content string? IIRC, it's just a no-op

Contributor Author

So, if image_tag and images are both none, it's a no-op. But otherwise, it'll error b/c if there's images, it expects an image_tag.
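
For the normal case, the docstring example shown earlier in this thread implies behavior roughly like this sketch (the exact handling of whitespace around the tag is an assumption):

from PIL import Image
from torchtune.data import format_content_with_images

img1 = Image.new("RGB", (4, 4))
img2 = Image.new("RGB", (4, 4))

content = format_content_with_images(
    "<|image|>hello <|image|>world",
    image_tag="<|image|>",
    images=[img1, img2],
)
# Roughly:
# [{"type": "image", "content": img1},
#  {"type": "text", "content": "hello "},
#  {"type": "image", "content": img2},
#  {"type": "text", "content": "world"}]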

@@ -54,6 +55,8 @@ class LlavaInstructToMessages(Transform):
new_system_prompt (Optional[str]): if specified, prepend a system message. This can
serve as instructions to guide the model response. Setting this will OVERRIDE any system
messages already present in the dataset. Default is None.
images_dir (str): path to the directory containing the images. User is expected to download the COCO dataset.
Contributor

not sure if we should have a default. This should be explicitly defined by the user, otherwise they may use the builder/transform, forget to specify it, and get confused when it breaks

Contributor Author

I like the default for something like this where it's for a very specific dataset.

pil_image = load_image(
self.images_dir + sample[self._column_map["image"]]
)
content = format_content_with_images(
Contributor

this reads really nicely now

@RdoubleA (Contributor) left a comment

heroic effort, thanks for this. I have no other concerns

@codecov-commenter

Codecov Report

Attention: Patch coverage is 95.62044% with 6 lines in your changes missing coverage. Please review.

Project coverage is 71.10%. Comparing base (66590b4) to head (5a74b30).

Files with missing lines Patch % Lines
torchtune/models/flamingo/_transform.py 0.00% 6 Missing ⚠️
Additional details and impacted files
@@             Coverage Diff             @@
##             main    #1504       +/-   ##
===========================================
+ Coverage   27.22%   71.10%   +43.88%     
===========================================
  Files         286      286               
  Lines       13828    13925       +97     
===========================================
+ Hits         3764     9901     +6137     
+ Misses      10064     4024     -6040     


@pbontrager (Contributor) left a comment

I left one comment on a potential change but otherwise it looks good to go.

from PIL import Image

# If pointing to remote source, try to load to local
if isinstance(image_loc, str) and image_loc.startswith("http"):
Contributor

It might be more robust to use urllib.parse to check if the string is a url. I don't think http is required to use urlopen.

Contributor Author

How? I can parse it, but urllib.parse parses it even if it's not a valid URL.

Limiting to http and https is restricting, sure, but we can relax when needed.
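
For context on that tradeoff, urllib.parse accepts plain local paths without complaint, so a scheme check is needed either way (quick illustration, not torchtune code):

from urllib.parse import urlparse

# A plain local path still "parses"; it just comes back with an empty scheme.
print(urlparse("coco/train2017/some_image.jpg").scheme)  # ''

# So in practice you end up checking the scheme anyway, which is what the
# startswith("http") check does more directly.
print(urlparse("https://example.com/image.jpg").scheme)  # 'https'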

Comment on lines +77 to +79
assert isinstance(
content, list
), f"content must be of type List[Dict[str, Any]], got {content}"
Contributor

I know we don't wanna do it here, but is there somewhere that we can/should validate that we have a PIL image when type="image"?

Contributor Author

validate_messages?

Contributor Author

Argh, I cannot do this without a nested import of PIL.

Contributor

OK fine to leave it as is then

if role == "system" and self.new_system_prompt is not None:
continue
content = split_text_by_image_tag(message["value"], "<image>")
if role == "user":
image_path = sample[self._column_map["image"]]
Contributor

So this will be present in every sample in the dataset?

Contributor Author

yup

torchtune/data/_utils.py (outdated, resolved)
Comment on lines +82 to +88
raise ValueError(f"Failed to load image from {image_loc}") from e

# Open the local image as a PIL image
try:
image = Image.open(image_loc)
except Exception as e:
raise ValueError(f"Failed to open image as PIL Image from {image_loc}") from e
Contributor

nit: are these really ValueErrors? (Or are you doing the thing I also do and lazily using ValueError as a catch-all for things that are not actually ValueErrors)

Contributor Author

WRT the first one, no it could be several different errors. Trying to call a remote URL always has plenty of exception types. I could just leave the error as is and let the user deal with it, but that seems messy.

WRT the second value error, more often than not, it will be a ValueError as the provided value is unable to be opened by the PIL Image interface.

Do you want me to remove the exceptions with error messages and just let it ride all the way to the user?

Contributor

OK thanks for the explanation, in that case I think what you have is fine

@joecummings merged commit eb92658 into pytorch:main on Sep 10, 2024
17 checks passed
@joecummings deleted the image-in-messages branch on September 10, 2024 20:07