Fix mantis processor
Signed-off-by: DarkLight1337 <[email protected]>
DarkLight1337 committed Dec 24, 2024
1 parent 9cd38b1 commit 14dcdd5
Showing 8 changed files with 93 additions and 37 deletions.
4 changes: 0 additions & 4 deletions .buildkite/test-pipeline.yaml
@@ -57,7 +57,6 @@ steps:
- pytest -v -s async_engine # AsyncLLMEngine
- NUM_SCHEDULER_STEPS=4 pytest -v -s async_engine/test_async_llm_engine.py
- pytest -v -s test_inputs.py
- pip install git+https://github.com/TIGER-AI-Lab/Mantis.git # For multimodal tests
- pytest -v -s multimodal
- pytest -v -s test_utils.py # Utils
- pytest -v -s worker # Worker
@@ -370,7 +369,6 @@ steps:
- tests/models/embedding/vision_language
- tests/models/encoder_decoder/vision_language
commands:
- pip install git+https://github.com/TIGER-AI-Lab/Mantis.git
- pytest -v -s models/decoder_only/audio_language -m 'core_model or quant_model'
- pytest -v -s --ignore models/decoder_only/vision_language/test_phi3v.py models/decoder_only/vision_language -m 'core_model or quant_model'
- pytest -v -s models/embedding/vision_language -m core_model
@@ -386,7 +384,6 @@ steps:
- tests/models/embedding/vision_language
- tests/models/encoder_decoder/vision_language
commands:
- pip install git+https://github.com/TIGER-AI-Lab/Mantis.git
- pytest -v -s models/decoder_only/audio_language -m 'not core_model and not quant_model'
- pytest -v -s models/decoder_only/vision_language/test_models.py -m 'split(group=0) and not core_model and not quant_model'
# HACK - run phi3v tests separately to sidestep this transformers bug
@@ -403,7 +400,6 @@ steps:
- vllm/
- tests/models/decoder_only/vision_language
commands:
- pip install git+https://github.com/TIGER-AI-Lab/Mantis.git
- pytest -v -s models/decoder_only/vision_language/test_models.py -m 'split(group=1) and not core_model and not quant_model'

# This test is used only in PR development phase to test individual models and should never run on main
3 changes: 1 addition & 2 deletions docs/source/models/supported_models.md
@@ -748,8 +748,7 @@ vLLM currently only supports adding LoRA to the language backbone of multimodal
```

```{note}
To use {code}`TIGER-Lab/Mantis-8B-siglip-llama3`, you have to install their GitHub repo ({code}`pip install git+https://github.com/TIGER-AI-Lab/Mantis.git`)
and pass {code}`--hf_overrides '{"architectures": ["MantisForConditionalGeneration"]}'` when running vLLM.
To use {code}`TIGER-Lab/Mantis-8B-siglip-llama3`, you have to pass {code}`--hf_overrides '{"architectures": ["MantisForConditionalGeneration"]}'` when running vLLM.
```

```{note}
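With the documentation change above, Mantis no longer needs the external package; only the architecture override is required. A minimal offline-inference sketch (not part of this commit, and assuming the `hf_overrides` engine argument mirrors the `--hf_overrides` CLI flag named in the note):

```python
# Illustrative only: running Mantis with the architecture override from the
# docs note. `hf_overrides` as an LLM keyword argument is assumed here to
# mirror the --hf_overrides CLI flag.
from vllm import LLM

llm = LLM(
    model="TIGER-Lab/Mantis-8B-siglip-llama3",
    hf_overrides={"architectures": ["MantisForConditionalGeneration"]},
)
```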
3 changes: 3 additions & 0 deletions vllm/inputs/registry.py
@@ -99,6 +99,9 @@ def get_hf_processor(

merged_kwargs = {**base_kwargs, **kwargs}

if isinstance(typ, type):
merged_kwargs["processor_cls"] = typ

hf_processor = cached_get_processor(
self.model_config.model,
trust_remote_code=self.model_config.trust_remote_code,
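The hunk above lets callers pass a concrete HF processor class as `typ`; when they do, the class is forwarded to the cached loader as `processor_cls`. A self-contained toy version of that dispatch (the function and class names here are illustrative, not from the diff):

```python
# Toy re-creation of the kwargs merging added above: a *class* passed as `typ`
# is forwarded as `processor_cls`; any other value leaves the kwargs untouched.
class DummyProcessor:  # stands in for a HF ProcessorMixin subclass
    pass

def merge_processor_kwargs(typ, base_kwargs, kwargs):
    merged_kwargs = {**base_kwargs, **kwargs}
    if isinstance(typ, type):
        merged_kwargs["processor_cls"] = typ
    return merged_kwargs

print(merge_processor_kwargs(DummyProcessor, {}, {}))
# {'processor_cls': <class '__main__.DummyProcessor'>}
print(merge_processor_kwargs(None, {}, {"num_crops": 4}))
# {'num_crops': 4}
```

This is what lets `MantisMultiModalProcessor._get_hf_processor` below simply return `self.ctx.get_hf_processor(LlavaProcessor)` instead of importing the Mantis repository.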
89 changes: 70 additions & 19 deletions vllm/model_executor/models/llava.py
@@ -6,7 +6,7 @@
import torch.nn as nn
from transformers import (BatchFeature, CLIPVisionConfig, LlavaConfig,
PixtralVisionConfig, PretrainedConfig,
ProcessorMixin, SiglipVisionConfig)
SiglipVisionConfig)
from transformers.models.llava import LlavaProcessor
from transformers.models.pixtral import PixtralProcessor

@@ -20,11 +20,13 @@
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalInputsV2, MultiModalKwargs,
NestedTensors)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
MultiModalDataItems, ProcessorInputs,
PromptReplacement)
PromptReplacement,
full_groupby_modality)
from vllm.sequence import IntermediateTensors

from .clip import (CLIPVisionModel, dummy_image_for_clip,
@@ -131,12 +133,12 @@ def _call_hf_processor(
mm_kwargs=mm_kwargs,
)

if "pixel_values" in processed_outputs:
# NOTE: pixel_values=None for MLlavaProcessor
pixel_values = processed_outputs.get("pixel_values")
if pixel_values is not None:
images = mm_data["images"]
assert isinstance(images, list)

pixel_values = processed_outputs["pixel_values"]

if isinstance(self._get_hf_processor(), PixtralProcessor):
# Original output: (1, num_images, C, H, W)
# New output: (num_images, C, H, W)
@@ -575,22 +577,71 @@ def load_weights(self, weights: Iterable[Tuple[str,

class MantisMultiModalProcessor(LlavaMultiModalProcessor):

def _get_hf_processor(self) -> ProcessorMixin:
try:
from mantis.models.mllava import MLlavaProcessor
except ModuleNotFoundError as exc:
raise ModuleNotFoundError(
"You need to `pip install "
"git+https://github.com/TIGER-AI-Lab/Mantis.git` "
"to use this model") from exc
def _get_hf_processor(self):
return self.ctx.get_hf_processor(LlavaProcessor)

processor = MLlavaProcessor.from_pretrained(
self.ctx.model_config.tokenizer)
assert isinstance(processor, MLlavaProcessor)
def apply(
self,
prompt_text: str,
mm_data: MultiModalDataDict,
hf_processor_mm_kwargs: Mapping[str, object],
) -> MultiModalInputsV2:
hf_config = self.ctx.get_hf_config(LlavaConfig)
image_token_id = hf_config.image_token_index
max_image_tokens = get_max_llava_image_tokens(self.ctx)

result = super().apply(prompt_text, mm_data, hf_processor_mm_kwargs)

mm_items = self._get_mm_items(mm_data)
mm_item_counts = mm_items.get_item_counts()
mm_kwargs = result["mm_kwargs"]

processor.image_token = "<image>" # type: ignore
# We reimplement the functionality of MLlavaProcessor from
# https://github.com/TIGER-AI-Lab/Mantis.git
def get_replacement_mantis(item_idx: int):
return "".join([
f"(image {item_idx+1}: <Image>", # 7 tokens
"<image>" * max_image_tokens,
"</Image>)", # 3 tokens
])

return processor
mantis_repls = self._bind_prompt_replacements([
PromptReplacement(
modality="image",
target=[image_token_id] * max_image_tokens,
replacement=get_replacement_mantis,
)
])

prompt_ids, prompt_text, _ = self._apply_prompt_replacements(
result["prompt_token_ids"],
mantis_repls,
mm_item_counts,
)

unbound_orig_repls = self._get_prompt_replacements(
mm_items,
hf_processor_mm_kwargs,
mm_kwargs,
)
orig_repls = self._bind_prompt_replacements(unbound_orig_repls)

all_placeholders = self._find_placeholders(orig_repls, prompt_ids,
mm_item_counts)
assert len(all_placeholders) == mm_item_counts.get("image", 0)

mm_placeholders = {
modality: [item.to_range() for item in items]
for modality, items in full_groupby_modality(all_placeholders)
}

return MultiModalInputsV2(
type="multimodal",
prompt=prompt_text,
prompt_token_ids=prompt_ids,
mm_kwargs=mm_kwargs,
mm_placeholders=mm_placeholders,
)


# To use this model, please use
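Because the external `MLlavaProcessor` is no longer imported, its prompt format is rebuilt inside `apply()` above. A standalone sketch of the string each image placeholder expands into; the real token count comes from `get_max_llava_image_tokens()`, so the value below is only a placeholder:

```python
# Standalone check of the Mantis prompt format reimplemented above.
MAX_IMAGE_TOKENS = 576  # placeholder; the diff derives this per model

def get_replacement_mantis(item_idx: int) -> str:
    return "".join([
        f"(image {item_idx + 1}: <Image>",  # 7 tokens
        "<image>" * MAX_IMAGE_TOKENS,       # one <image> per image token
        "</Image>)",                        # 3 tokens
    ])

sample = get_replacement_mantis(0)
print(sample[:24])              # (image 1: <Image><image>
print(sample.count("<image>"))  # 576
```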
4 changes: 2 additions & 2 deletions vllm/model_executor/models/phi3v.py
@@ -452,11 +452,11 @@ def apply(
# Only <|image|> tokens should be considered as placeholders,
# so we ignore the trailing bos_token_id
result["mm_placeholders"] = {
k: [
modality: [
PlaceholderRange(offset=p["offset"], length=p["length"] - 1)
for p in ps
]
for k, ps in result["mm_placeholders"].items()
for modality, ps in result["mm_placeholders"].items()
}

return result
5 changes: 3 additions & 2 deletions vllm/multimodal/inputs.py
@@ -270,8 +270,9 @@ def from_hf_inputs(
# NOTE: This skips fields in `hf_inputs` that are not in `config_by_key`
# We assume that those fields are not used in vLLM
items_by_key = {
key: config.build_items(key, hf_inputs[key])
for key, config in config_by_key.items() if key in hf_inputs
key: config.build_items(key, batch)
for key, config in config_by_key.items()
if (batch := hf_inputs.get(key)) is not None
}

return MultiModalKwargs.from_items_by_key(
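The change above tightens the filter from "key is present" to "value is not None", which matters because the Mantis/MLlava path can return `pixel_values=None` for text-only prompts (see the `_call_hf_processor` note in llava.py). A toy illustration with plain dicts instead of vLLM types:

```python
# Keys that are present but map to None are now skipped when building items.
hf_inputs = {"input_ids": [[1, 2, 3]], "pixel_values": None}
config_by_key = {"input_ids": "cfg_a", "pixel_values": "cfg_b"}

items_by_key = {
    key: (config, batch)
    for key, config in config_by_key.items()
    if (batch := hf_inputs.get(key)) is not None
}
print(items_by_key)  # {'input_ids': ('cfg_a', [[1, 2, 3]])}
```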
10 changes: 5 additions & 5 deletions vllm/multimodal/processing.py
@@ -1063,18 +1063,18 @@ def apply(
hf_processor_mm_kwargs,
)

prompt_repls = self._get_prompt_replacements(
unbound_prompt_repls = self._get_prompt_replacements(
mm_items,
hf_processor_mm_kwargs,
mm_kwargs,
)
all_prompt_repls = self._bind_prompt_replacements(prompt_repls)
prompt_repls = self._bind_prompt_replacements(unbound_prompt_repls)

# If HF processor already inserts placeholder tokens,
# there is no need for us to insert them
mm_item_counts = mm_items.get_item_counts()
all_placeholders = self._find_placeholders(all_prompt_repls,
prompt_ids, mm_item_counts)
all_placeholders = self._find_placeholders(prompt_repls, prompt_ids,
mm_item_counts)

if all_placeholders:
tokenizer = self._get_tokenizer()
@@ -1086,7 +1086,7 @@ def apply(
all_placeholders,
) = self._apply_prompt_replacements(
prompt_ids,
all_prompt_repls,
prompt_repls,
mm_item_counts,
)

12 changes: 9 additions & 3 deletions vllm/transformers_utils/processor.py
@@ -1,25 +1,31 @@
from functools import lru_cache
from typing import Any, cast

from transformers.processing_utils import ProcessorMixin


def get_processor(
processor_name: str,
*args: Any,
trust_remote_code: bool = False,
processor_cls: type[ProcessorMixin] = ProcessorMixin,
**kwargs: Any,
):
"""Load a processor for the given model name via HuggingFace."""
# don't put this import at the top level
# it will call torch.cuda.device_count()
from transformers import AutoProcessor
from transformers.processing_utils import ProcessorMixin

processor_factory = (AutoProcessor
if processor_cls == ProcessorMixin else processor_cls)

try:
processor = AutoProcessor.from_pretrained(
processor = processor_factory.from_pretrained(
processor_name,
*args,
trust_remote_code=trust_remote_code,
**kwargs)
**kwargs,
)
except ValueError as e:
# If the error pertains to the processor class not existing or not
# currently being imported, suggest using the --trust-remote-code flag.
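With the `processor_cls` parameter in place, the helper can either keep its old `AutoProcessor` behaviour or load through a specific class. Hedged call-site sketches (the model name is just an example and is not taken from this diff; both calls assume the `get_processor` shown above is in scope):

```python
from transformers.models.llava import LlavaProcessor

# Unchanged path: the default ProcessorMixin sentinel resolves via AutoProcessor.
processor = get_processor("llava-hf/llava-1.5-7b-hf")

# New path: load through LlavaProcessor.from_pretrained, which is how the
# Mantis processor above requests a plain LlavaProcessor explicitly.
processor = get_processor("llava-hf/llava-1.5-7b-hf",
                          processor_cls=LlavaProcessor)
```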
