Images in Messages #1504
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1504
✅ No failures as of commit 6fe061b with merge base 66590b4. This comment was automatically generated by Dr. CI and updates every 15 minutes.
This would mean that the model transform just takes in a List[Message] directly instead of a sample dictionary, or do we still plan to pass in sample["messages"] to model transforms?
@@ -89,6 +89,8 @@ def __call__(self, sample: Mapping[str, Any]) -> Mapping[str, Any]:
                    role="system", content=self.new_system_prompt, masked=True, eot=True
                )
            )

+           # Add in image stuffs / load from file
I am leaning towards handling image loading in the model transform and keeping a URL here, so we don't need to pass in a heavy PIL image everywhere Messages is used. We only load in the image when it's absolutely needed. For tokenize_messages, messages would be more lightweight since it won't contain the actual image.
Re passing around PIL images, do you mean that there are memory implications to doing something like this? Because if we are just passing stuff by reference it shouldn't be an issue, right?
If there were a bunch of operations on messages that'd be more convenient with the file path, that'd be one thing. But it seems to me like most (nontrivial) functions we're gonna call on Messages containing images would be those that act on the image, not the path.
In that case I'd be inclined to keep things as a PIL image because it's the raw datatype of an image. It's analogous to the raw string we have in the content field of a Message with text type. I think {"text" -> str, "image" -> PIL} is a lot more natural for a user than {"text" -> str, "image" -> Path}. But lmk if I'm missing the point here.
See comment in FAQs. Generally, I agree, but I'm making the tradeoff for "weight" of dependencies.
Right now, it's possible to train and run inference without downloading extra things needed for multimodal like PIL and torchvision
I think Joe's point here is quite valid. I agree having it PIL image by default is more intuitive but as of now multimodal is not prevalent enough as a use case.
tests/assets/dog_on_skateboard.jpg
Outdated
totally tubular
"Requests changes"
😭
@@ -55,7 +55,8 @@ def test_label_no_masking(self, load_dataset, tokenizer):
            model_transform=tokenizer,
            train_on_input=True,
        )
-       input, labels, images = ds[0]["tokens"], ds[0]["labels"], ds[0]["images"]
+       input, labels = ds[0]["tokens"], ds[0]["labels"]
we should have DummyTokenizer load in images from the message (to replicate the model transform) and then check that the image tensor is returned here as expected. that way it would be a good e2e test of passing the images through the messages into the model transform
Not sure I follow. Our tokenizers (model transforms) take in images through the messages, but they do not return them afterwards.
Looking at Flamingo transform for this.
we're using DummyTokenizer as the substitute for a model transform here. Really it should be called DummyModelTransform. My point is we should test that images are passed through from the dataset to messages and are processed correctly. Right now with these changes we are not checking images at all, only the image special tokens.
-            if isinstance(content, str)
-            else content
-        )
+        self.content = self._convert_to_list_of_dict(content)
very nice
-        {"type": "image"},
+        {
+            "type": "image",
+            "content": sample[self._column_map["images"]],
this will be a PIL image. can image content still double as PIL image or path string?
I mean it can - but it should be consistent
Damn, if this is coming in as a PIL image, then it might make sense to go ahead with the changes to load image into Message as a PIL image
ha... thinking about it a bit more, I suppose there's no actual memory hit. if you're using multimodal datasets you will certainly need to load the image at some point. then again, changing my mind after five minutes hurts my pride just a little
I don't really want a message to be Union[Path, str, PIL.Image]. That seems like a recipe for confusion.
ALRIGHT F IT, here's the plan:
I was a big ol dumb dumb. This stuff is gunna go. Messages will now hold PIL images. This will be loosely typed in Messages so that we don't need to actually import PIL, and Messages can stay free of that dependency hell. Then the onus for providing a proper PIL image to the Message class is on the dataset builder and the generation recipe. For generation, we could just have a separate recipe for multimodal or require PIL for everything. For dataset builders, everything is in the multimodal folder, which protects us from imports (thanks for the forward thinking @RdoubleA).
This makes the most logical sense IMO.
Data entry points (datasets, generation) load in the PIL image -> feed to message in proper format -> message is used by model transform however it wants.
Okay? Okay.
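A loose sketch of what this plan implies for the Message class (field names are assumed from the snippets in this PR, not the actual implementation): Message stores whatever image object the caller loaded, without ever importing PIL itself.

```python
from typing import Any, Dict, List, Union


class Message:
    """Loosely-typed sketch: image content is whatever object the entry point
    loaded (e.g. a PIL image), so this module never imports PIL itself."""

    def __init__(self, role: str, content: Union[str, List[Dict[str, Any]]]):
        self.role = role
        # Normalize plain strings into the list-of-dict content format
        if isinstance(content, str):
            content = [{"type": "text", "content": content}]
        self.content = content


# The entry point (dataset builder / generation recipe) loads the image...
image = object()  # stand-in for a PIL.Image.Image loaded via some load_image()
# ...and hands it to the Message in the proper format:
m = Message(
    "user",
    [{"type": "image", "content": image}, {"type": "text", "content": "What is this?"}],
)
print(len(m.content))  # 2
```

The model transform can then consume `m.content` however it wants, and only multimodal code paths ever touch PIL.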
Idk what just happened here but I'm on board
ALRIGHT F IT
you right, I'm on board
I feel that
# TODO: point to Flamingo model transform as an example
def llava_instruct_dataset(
    model_transform: Transform,
    *,
    source: str = "liuhaotian/LLaVA-Instruct-150K",
    images_dir: str = "coco/",
I would love some input here. Right now this is a string b/c I'm anticipating usage from a config. However, the proper typing would be something like pathlib.Path b/c it would be much more tolerant of errant forward slashes and plays better when combining subpaths.
But then, I'll have to do some conversion on the backside to make that work.
Thoughts? @RdoubleA @ebsmothers
If we convert to Path on the backside but still maintain a string parameter, that should handle any trailing/incorrect slashes right? If so I would do that
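A minimal sketch of that suggestion, assuming a pared-down hypothetical helper (the real builder takes many more arguments): keep the config-facing parameter a str, and normalize to pathlib.Path internally so joins and stray slashes are handled.

```python
from pathlib import Path


def resolve_image_path(images_dir: str, filename: str) -> Path:
    # Path() absorbs a trailing slash, and the / operator joins cleanly,
    # so config strings like "coco" and "coco/" behave identically
    return Path(images_dir) / filename


print(resolve_image_path("coco", "0001.jpg"))   # coco/0001.jpg
print(resolve_image_path("coco/", "0001.jpg"))  # coco/0001.jpg
```

The config stays plain strings; only the internals pay the conversion cost.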
        PIL.Image.Image: The loaded image.
    """
    # Hackily import PIL to avoid burdensome import in the main module
    # TODO: Fix this
isn't this import already pretty prevalent throughout the library? or are you avoiding some other issue?
It's only prevalent if doing MM. Right now, you can completely run all our text stuff without the need to import torchvision, PIL, etc...
I'd love to keep it that way.
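The deferred-import pattern being defended here can be sketched like this (a generic sketch, not the actual torchtune code): PIL is only imported when an image is actually loaded, so text-only users never need it installed.

```python
def load_image(image_loc):
    """Load an image from a local path; requires Pillow only at call time."""
    try:
        # Deferred so the module imports cleanly without Pillow installed
        from PIL import Image
    except ImportError as e:
        raise ImportError(
            "Loading images requires Pillow: pip install Pillow"
        ) from e
    return Image.open(image_loc)
```

Text-only pipelines never hit this code path, so the top-level package keeps zero multimodal dependencies.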
What's the fix here? Honestly I think this is an OK solution
Would be putting multimodal stuff in a subfolder likely.
torchtune/data/_utils.py
Outdated
@@ -53,7 +90,7 @@ def split_text_by_image_tag(content: str, image_tag: str) -> List[Dict[str, str]]
    "role": "system" | "user" | "assistant",
    "content":
        [
-           {"type": "image"},
+           {"type": "image", "content": "path/to/image1.png"},
this needs to be PIL image
u need to be a pil image 😡
torchtune/data/_utils.py
Outdated
>>> content = format_content_with_media(
...     "<|image|>hello <|image|>world",
...     image_tag="<|image|>",
...     images=[<PIL.Image.Image>, <PIL.Image.Image>"]
-     ... images=[<PIL.Image.Image>, <PIL.Image.Image>"]
+     ... images=[<PIL.Image.Image>, <PIL.Image.Image>]
    return image


def format_content_with_images(
nit: maybe the name insert_images_in_content is clearer about what the function is doing
True, but it also does the splitting?
splitting so we can insert the images in the correct location :) but, your call
torchtune/data/_utils.py
Outdated
def format_content_with_images(
    content: str, *, image_tag: str, images: List["PIL.Image.Image"]
) -> List[Dict[str, str]]:
    """
    Given a raw text string, split by the specified ``image_tag``
    and form into list of dictionaries to be used in the ``Message`` content
nit: while you're here, could you describe what happens if image tag is not found in the content string? IIRC, it's just a no-op
So, if image_tag and images are both None, it's a no-op. Otherwise it'll error, b/c if there are images, it expects an image_tag.
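The behavior described above could be sketched as follows; the signature comes from the diff in this PR, but the body is an assumption rather than torchtune's actual implementation. The text is split on the tag, and each tag occurrence consumes one image in order.

```python
from typing import Any, Dict, List


def format_content_with_images(
    content: str, *, image_tag: str, images: List[Any]
) -> List[Dict[str, Any]]:
    # Each occurrence of image_tag marks where one image belongs
    parts = content.split(image_tag)
    if len(parts) - 1 != len(images):
        raise ValueError(
            f"Number of image tags ({len(parts) - 1}) does not match "
            f"number of images ({len(images)})"
        )
    formatted: List[Dict[str, Any]] = []
    for i, text in enumerate(parts):
        if text:  # skip empty text chunks (e.g. a leading tag)
            formatted.append({"type": "text", "content": text})
        if i < len(images):
            formatted.append({"type": "image", "content": images[i]})
    return formatted


result = format_content_with_images(
    "<|image|>hello <|image|>world", image_tag="<|image|>", images=["img1", "img2"]
)
print(result)
```

With no tags and no images, the split yields one text chunk and the function is effectively a no-op; a tag/image count mismatch raises.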
@@ -54,6 +55,8 @@ class LlavaInstructToMessages(Transform):
    new_system_prompt (Optional[str]): if specified, prepend a system message. This can
        serve as instructions to guide the model response. Setting this will OVERRIDE any system
        messages already present in the dataset. Default is None.
+   images_dir (str): path to the directory containing the images. User is expected to download the COCO dataset.
not sure if we should have a default. this should be explicitly defined by the user, otherwise they may use the builder/transform, forget to specify it, and get confused when it breaks
I like the default for something like this where it's for a very specific dataset.
pil_image = load_image(
    self.images_dir + sample[self._column_map["image"]]
)
content = format_content_with_images(
this reads really nicely now
Force-pushed from 47ff58f to 5a74b30
heroic effort, thanks for this. I have no other concerns
Codecov Report
Attention: Patch coverage is

Additional details and impacted files:
@@             Coverage Diff             @@
##             main    #1504       +/-   ##
===========================================
+ Coverage   27.22%   71.10%   +43.88%
===========================================
  Files         286      286
  Lines       13828    13925      +97
===========================================
+ Hits         3764     9901    +6137
+ Misses      10064     4024    -6040

☔ View full report in Codecov by Sentry.
I left one comment on a potential change but otherwise it looks good to go.
from PIL import Image

# If pointing to remote source, try to load to local
if isinstance(image_loc, str) and image_loc.startswith("http"):
It might be more robust to use urllib.parse to check if the string is a url. I don't think http is required to use urlopen.
How? I can parse it, but urllib.parse parses it even if it's not a valid URL.
Limiting to http and https is restricting, sure, but we can relax when needed.
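A quick sketch of the tradeoff being discussed: urllib.parse.urlparse "succeeds" on almost any string, so the useful signal is whether it yields a recognized scheme, which lands close to the startswith("http") check anyway.

```python
from urllib.parse import urlparse


def is_http_url(s: str) -> bool:
    # urlparse never rejects a plain string; check the scheme instead
    return urlparse(s).scheme in ("http", "https")


print(is_http_url("http://example.com/dog.jpg"))  # True
print(is_http_url("coco/train2017/0001.jpg"))     # False
print(repr(urlparse("not a url").scheme))         # '' -- no error, no scheme
```

Restricting to http/https is exactly what the existing check does; the scheme test just tolerates leading whitespace and odd casing a bit differently.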
assert isinstance(
    content, list
), f"content must be of type List[Dict[str, Any]], got {content}"
I know we don't wanna do it here, but is there somewhere that we can/should validate that we have a PIL image when type="image"?
validate_messages?
Argh, I cannot do this without a nested import of PIL.
OK fine to leave it as is then
if role == "system" and self.new_system_prompt is not None:
    continue
content = split_text_by_image_tag(message["value"], "<image>")
if role == "user":
    image_path = sample[self._column_map["image"]]
So this will be present in every sample in the dataset?
yup
    raise ValueError(f"Failed to load image from {image_loc}") from e

# Open the local image as a PIL image
try:
    image = Image.open(image_loc)
except Exception as e:
    raise ValueError(f"Failed to open image as PIL Image from {image_loc}") from e
nit: are these really ValueErrors? (Or are you doing the thing I also do and lazily using ValueError as a catch-all for things that are not actually ValueErrors)
WRT the first one, no it could be several different errors. Trying to call a remote URL always has plenty of exception types. I could just leave the error as is and let the user deal with it, but that seems messy.
WRT the second value error, more often than not, it will be a ValueError as the provided value is unable to be opened by the PIL Image interface.
Do you want me to remove the exceptions with error messages and just let it ride all the way to the user?
OK thanks for the explanation, in that case I think what you have is fine
Co-authored-by: ebsmothers <[email protected]>
Put image content in Message (and other suspicious changes)

Background
No snappy intro, I'm just tired.
This all started in considering the inference UX for multimodal models. The ideal state would be something like specifying the image and prompt in the config, and then the user can simply call the generate recipe.
But, what is this getting at?
The concept is that a "Message" should contain all the information needed for an interaction with a model, INCLUDING images. Before, images were split off from Messages and a placeholder was put in the Message object so the model knew where to inject the content. After these changes, the model now has ALL the information it needs within the message content itself.
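Schematically (with made-up values, not the exact code from the PR), the before/after looks something like:

```python
# Illustrative sketch only -- values and shapes are made up to show the idea.

# Before: the Message held a placeholder, and images traveled separately.
before_content = [
    {"type": "image"},  # placeholder: the model knew *where*, not *what*
    {"type": "text", "content": "What is in this image?"},
]
separate_images = ["<PIL.Image.Image>"]  # carried alongside, not inside

# After: the image content rides inside the Message itself.
after_content = [
    {"type": "image", "content": "<PIL.Image.Image>"},
    {"type": "text", "content": "What is in this image?"},
]
```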
The core of the change is relatively simple; however, we already backed the split images work into all our multimodal stuff, so that all had to be reworked as well. See below for a full changelog.
Changelog

- Added a get_media function to easily get all the image content from a Message

Testing
FAQs
Why did you make it possible to load from remote or local sources? This is a feature usually demonstrated for inference, as in Phi3 Vision and Qwen2 Vision. We don't want to force users to download the image to the device first, so this handles both.
Why not convert to PIL image within Message?
I wanted to do this originally; however, Message is a class used for both text and multimodal models. Right now, it's possible to train and run inference without downloading extra things needed for multimodal like PIL and torchvision. This keeps the library lightweight. In keeping with that, if we make the specific multimodal model transforms load and use PIL, we keep this usage separate. I'd be open to changing this if PIL and multimodal become so important that they're used everywhere, but right now they're not.

I did it, sue me.