Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[V1] Support Pixtral-HF on V1 #11409

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion vllm/model_executor/models/llava.py
Original file line number Diff line number Diff line change
Expand Up @@ -477,7 +477,17 @@ def _process_image_input(self,

assert self.vision_tower is not None
image_features = self._process_image_pixels(image_input)
return self.multi_modal_projector(image_features)

if isinstance(image_features, torch.Tensor):
return self.multi_modal_projector(image_features)

feature_sizes = [
image_feature.shape[0] for image_feature in image_features
]

image_embeds = self.multi_modal_projector(torch.cat(image_features))
image_embeds = torch.split(image_embeds, feature_sizes)
return image_embeds

def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
image_input = self._parse_and_validate_image_input(**kwargs)
Expand Down
10 changes: 8 additions & 2 deletions vllm/model_executor/models/pixtral.py
Original file line number Diff line number Diff line change
Expand Up @@ -1009,9 +1009,13 @@ def forward(
for img in pixel_values
]

patch_embeds = [
p.flatten(2).permute(0, 2, 1) for p in patch_embeds_list
]
embed_sizes = [p.shape[1] for p in patch_embeds]

# flatten to a single sequence
patch_embeds = torch.cat(
[p.flatten(2).permute(0, 2, 1) for p in patch_embeds_list], dim=1)
patch_embeds = torch.cat(patch_embeds, dim=1)
patch_embeds = self.ln_pre(patch_embeds)

# positional embeddings
Expand Down Expand Up @@ -1042,6 +1046,8 @@ def forward(
out = resolve_visual_encoder_outputs(out, feature_sample_layers, None,
self.config.num_hidden_layers)

# squeeze dim 0 and split into separate tensors for each image
out = torch.split(torch.squeeze(out), embed_sizes)
return out

# (TODO) Add prefix argument for filtering out weights to be loaded
Expand Down
Loading