
[VLM] Support caching in merged multi-modal processor #11396

Open
wants to merge 75 commits into base: main
Commits (75)
faa9b84
Refactor multi-modal processor to support caching
DarkLight1337 Dec 19, 2024
9711a15
Clean up
DarkLight1337 Dec 19, 2024
29e3fcd
Fix cached result being mutated
DarkLight1337 Dec 19, 2024
ab64e85
Rename
DarkLight1337 Dec 19, 2024
81215a2
Fix docs
DarkLight1337 Dec 19, 2024
cf52b3b
Fix a typo
DarkLight1337 Dec 19, 2024
a4a8eb9
Fix unhandled sampling rate in initialization
DarkLight1337 Dec 19, 2024
c48f7c5
format
DarkLight1337 Dec 19, 2024
b84ff42
Change the delimiter
DarkLight1337 Dec 19, 2024
c3f1bde
Fix extra dimension
DarkLight1337 Dec 19, 2024
32e5197
Update
DarkLight1337 Dec 19, 2024
7264d4e
Use the inner processor to enable fine-grained caching
DarkLight1337 Dec 20, 2024
02ea829
Make the cache optional
DarkLight1337 Dec 20, 2024
b981a9d
Fix invalid kwargs being passed to tokenizer
DarkLight1337 Dec 20, 2024
5dde7d0
Fix Phi3V prompt replacement
DarkLight1337 Dec 20, 2024
7339ab8
Refine
DarkLight1337 Dec 20, 2024
509411d
Enable fine-grained caching for audio models
DarkLight1337 Dec 20, 2024
c0454f5
Add fallback
DarkLight1337 Dec 20, 2024
d50ef03
Fix typo
DarkLight1337 Dec 20, 2024
81f7d61
Fix video processor for Qwen2-VL
DarkLight1337 Dec 20, 2024
13eede3
Merge branch 'main' into mm-processor-cache
DarkLight1337 Dec 20, 2024
affbc5c
Fix a bunch of type errors
DarkLight1337 Dec 20, 2024
b4ddfb1
Fix qwen2-vl
DarkLight1337 Dec 20, 2024
4b3db32
Fix
DarkLight1337 Dec 20, 2024
dafbc7f
Simplify Pixtral-HF
DarkLight1337 Dec 21, 2024
38aaff8
Cleanup
DarkLight1337 Dec 21, 2024
5fcb5d6
Fix Pixtral-HF
DarkLight1337 Dec 21, 2024
f86e148
Enable caching outside the processing loop
DarkLight1337 Dec 21, 2024
337f0d2
Make debugging easier
DarkLight1337 Dec 21, 2024
c01d38a
Update
DarkLight1337 Dec 21, 2024
84f02fb
Fix ultravox
DarkLight1337 Dec 21, 2024
9f417c2
Revert some unnecessary changes
DarkLight1337 Dec 21, 2024
00b765b
Merge branch 'main' into mm-fields
DarkLight1337 Dec 22, 2024
2ed431e
Add test and fix some issues
DarkLight1337 Dec 23, 2024
baaf551
Update
DarkLight1337 Dec 23, 2024
f5dbcb8
Fix
DarkLight1337 Dec 23, 2024
afd3f4f
Rework
DarkLight1337 Dec 23, 2024
6172450
Rename the test
DarkLight1337 Dec 23, 2024
416943d
Update count
DarkLight1337 Dec 23, 2024
86f2786
Rename
DarkLight1337 Dec 23, 2024
f5b6214
Some fixes
DarkLight1337 Dec 23, 2024
8a68e87
Cleanup
DarkLight1337 Dec 23, 2024
ab7e84b
Skip unspecified fields
DarkLight1337 Dec 23, 2024
9f2cdaa
Fix equality checking
DarkLight1337 Dec 23, 2024
d11e833
Consolidate common code
DarkLight1337 Dec 23, 2024
5fee280
Improve error message
DarkLight1337 Dec 23, 2024
6182fd6
Cleanup
DarkLight1337 Dec 23, 2024
e1214cf
Fix Pixtral-HF
DarkLight1337 Dec 23, 2024
c717bce
Fix missing mm_count key
DarkLight1337 Dec 23, 2024
023890e
Fix qwen2-vl
DarkLight1337 Dec 23, 2024
b5e5b8a
Fix Qwen2-VL
DarkLight1337 Dec 23, 2024
cf24a1b
Fix Qwen2-VL and Qwen2-Audio
DarkLight1337 Dec 23, 2024
73271e9
Debug Phi3V
DarkLight1337 Dec 23, 2024
e30deec
Consolidate common code
DarkLight1337 Dec 23, 2024
ea6f8b5
Try to fix Phi3V and Ultravox
DarkLight1337 Dec 23, 2024
10ae755
Remove benchmark
DarkLight1337 Dec 23, 2024
85c5e2c
Fix token mismatch in Phi3V and Ultravox
DarkLight1337 Dec 23, 2024
4873ff8
Update max image tokens
DarkLight1337 Dec 23, 2024
4dbb5a3
Strictly check the number of placeholder tokens
DarkLight1337 Dec 23, 2024
6dbae81
Fix doc failure
DarkLight1337 Dec 23, 2024
fb51c9b
Test and fix Mantis processor
DarkLight1337 Dec 24, 2024
91cbd63
Fix embedding inputs
DarkLight1337 Dec 24, 2024
6bee6ba
Update entrypoints tests
DarkLight1337 Dec 24, 2024
cfa2ce8
Merge branch 'main' into mm-fields
DarkLight1337 Dec 24, 2024
fa54292
Clean up
DarkLight1337 Dec 24, 2024
cbf79be
Avoid extra placeholder in phi3v
DarkLight1337 Dec 24, 2024
9cd38b1
Fix OOM
DarkLight1337 Dec 24, 2024
14dcdd5
Fix mantis processor
DarkLight1337 Dec 24, 2024
b8bd2d4
Merge branch 'main' into mm-fields
DarkLight1337 Dec 24, 2024
5045d93
Remove redundant code
DarkLight1337 Dec 24, 2024
4cac998
Still need Mantis repo for testing
DarkLight1337 Dec 24, 2024
e8afd10
Merge branch 'main' into mm-fields
DarkLight1337 Dec 25, 2024
93bba0a
Fix incorrect max image tokens (Updated in #11258)
DarkLight1337 Dec 25, 2024
ea9f888
Also cache by model ID
DarkLight1337 Dec 25, 2024
58747f6
Format
DarkLight1337 Dec 25, 2024
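
The commit history above outlines the design: HF processor outputs are memoized per multi-modal item, the cache is optional, caching is fine-grained via the inner processor, and entries are keyed by model ID as well as by the item itself. Below is a minimal sketch of that idea only; it is not vLLM's actual `ProcessingCache` (all names are assumptions, and the real code likely hashes with blake3, judging by the entry added to `docs/source/conf.py` further down, whereas this sketch uses the standard library):

```python
# Illustrative sketch only -- not vLLM's ProcessingCache. It memoizes
# per-item processor outputs under a key combining the model ID, the
# modality, and a content hash of the raw data, with LRU eviction.
import hashlib
import pickle
from collections import OrderedDict
from typing import Any, Callable


class NaiveProcessingCache:

    def __init__(self, capacity: int) -> None:
        # Capacity counts items here for simplicity; the cache in the new
        # test is created with capacity=1 << 30 so it can fit all the data.
        self._capacity = capacity
        self._cache = OrderedDict()  # key (str) -> processed output

    def _key(self, model_id: str, modality: str, item: Any) -> str:
        digest = hashlib.sha256(pickle.dumps(item)).hexdigest()
        return f"{model_id}:{modality}:{digest}"

    def get_or_compute(self, model_id: str, modality: str, item: Any,
                       compute: Callable[[Any], Any]) -> Any:
        key = self._key(model_id, modality, item)
        if key in self._cache:
            self._cache.move_to_end(key)  # Mark as most recently used
            return self._cache[key]
        value = compute(item)
        self._cache[key] = value
        if len(self._cache) > self._capacity:
            self._cache.popitem(last=False)  # Evict least recently used
        return value
```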
3 changes: 2 additions & 1 deletion docs/source/conf.py
@@ -162,6 +162,7 @@ def linkcode_resolve(domain, info):

# Mock out external dependencies here, otherwise the autodoc pages may be blank.
autodoc_mock_imports = [
"blake3",
"compressed_tensors",
"cpuinfo",
"cv2",
@@ -178,7 +179,7 @@ def linkcode_resolve(domain, info):
"tensorizer",
"pynvml",
"outlines",
"xgrammar,"
"xgrammar",
"librosa",
"soundfile",
"gguf",
3 changes: 1 addition & 2 deletions docs/source/models/supported_models.md
@@ -750,8 +750,7 @@ vLLM currently only supports adding LoRA to the language backbone of multimodal
```

```{note}
To use {code}`TIGER-Lab/Mantis-8B-siglip-llama3`, you have to install their GitHub repo ({code}`pip install git+https://github.com/TIGER-AI-Lab/Mantis.git`)
and pass {code}`--hf_overrides '{"architectures": ["MantisForConditionalGeneration"]}'` when running vLLM.
To use {code}`TIGER-Lab/Mantis-8B-siglip-llama3`, you have to pass {code}`--hf_overrides '{"architectures": ["MantisForConditionalGeneration"]}'` when running vLLM.
@DarkLight1337 (Member, Author) commented on Dec 24, 2024:
We now use HF's LlavaProcessor + our own prompt replacements to replicate the logic of MLlavaProcessor, so users don't have to install their GitHub anymore.

```
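
For context on the comment above: instead of Mantis' custom MLlavaProcessor, the PR runs HF's stock LlavaProcessor and then rewrites the prompt so each image placeholder expands into the token run the model expects. The snippet below is a framework-free sketch of that expansion step only; the token ID and per-image count are invented, and the real logic lives in vLLM's prompt-replacement machinery rather than in a helper like this.

```python
# Framework-free illustration of the "prompt replacement" idea described in
# the comment above; token IDs and counts are invented for the example.
def expand_image_placeholders(token_ids: list[int], image_token_id: int,
                              tokens_per_image: int) -> list[int]:
    """Expand each single <image> token into a run of placeholder tokens."""
    out: list[int] = []
    for tok in token_ids:
        if tok == image_token_id:
            out.extend([image_token_id] * tokens_per_image)
        else:
            out.append(tok)
    return out


# One <image> token (hypothetical id 32000) becomes 576 placeholder slots.
assert len(expand_image_placeholders([1, 32000, 2], 32000, 576)) == 578
```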

```{note}
4 changes: 2 additions & 2 deletions tests/entrypoints/openai/test_vision_embedding.py
@@ -91,5 +91,5 @@ async def test_image_embedding(server: RemoteOpenAIServer, model_name: str,
assert len(embeddings.data) == 1
assert len(embeddings.data[0].embedding) == 3072
assert embeddings.usage.completion_tokens == 0
assert embeddings.usage.prompt_tokens == 765
assert embeddings.usage.total_tokens == 765
assert embeddings.usage.prompt_tokens == 764
assert embeddings.usage.total_tokens == 764
@@ -30,7 +30,7 @@ def get_max_qwen2_vl_image_tokens():


@pytest.mark.parametrize("mm_processor_kwargs,expected_max_tokens", [
({}, 1225),
({}, 16384),
({
MIN_PIXELS: 64**2,
MAX_PIXELS: 512**2
4 changes: 3 additions & 1 deletion tests/models/decoder_only/vision_language/test_models.py
@@ -201,6 +201,7 @@
vllm_output_post_proc=model_utils.fuyu_vllm_to_hf_output,
num_logprobs=10,
image_size_factors=[(), (0.25,), (0.25, 0.25, 0.25), (0.25, 0.2, 0.15)],
marks=[large_gpu_mark(min_gb=48)],
),
"glm4": VLMTestInfo(
models=["THUDM/glm-4v-9b"],
@@ -212,7 +213,7 @@
dtype="bfloat16",
get_stop_token_ids=lambda tok: [151329, 151336, 151338],
patch_hf_runner=model_utils.glm_patch_hf_runner,
marks=[large_gpu_mark(min_gb=48)],
marks=[large_gpu_mark(min_gb=32)],
),
"h2ovl": VLMTestInfo(
models = [
@@ -261,6 +262,7 @@
dtype="bfloat16",
use_tokenizer_eos=True,
patch_hf_runner=model_utils.internvl_patch_hf_runner,
marks=[large_gpu_mark(min_gb=32)],
),
"llava_next": VLMTestInfo(
models=["llava-hf/llava-v1.6-mistral-7b-hf"],
210 changes: 204 additions & 6 deletions tests/multimodal/test_processing.py
@@ -1,12 +1,20 @@
from functools import partial
from typing import cast

import numpy as np
import pytest

from vllm.multimodal.processing import (PromptReplacement, _PlaceholderInfo,
find_text_matches, find_token_matches,
iter_placeholders, iter_token_matches,
from PIL import Image

from vllm.config import ModelConfig
from vllm.inputs import InputProcessingContext
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.processing import (ProcessingCache, PromptReplacement,
_PlaceholderInfo, find_text_matches,
find_token_matches, iter_placeholders,
iter_token_matches,
replace_text_matches,
replace_token_matches)
from vllm.multimodal.utils import cached_get_tokenizer
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import full_groupby

@@ -457,6 +465,7 @@ def test_find_replace_tokens(
),
]
)
# yapf: enable
def test_iter_placeholders(
repl_by_key,
prompt,
@@ -475,11 +484,200 @@ def test_iter_placeholders(
prompt_repls,
prompt,
# Effectively match all occurrences in the prompt
{key: 3 for key in repl_by_key},
))
{key: 3
for key in repl_by_key},
))

# Only displayed on error
print("result:", result)

# Manually constructed results
assert result == expected


def _rand_img(rng: np.random.RandomState, min_wh: int, max_wh: int):
w, h = rng.randint(min_wh, max_wh, size=(2, ))
arr = rng.randint(0, 255, size=(w, h, 3), dtype=np.uint8)
return Image.fromarray(arr)


def _rand_video(
rng: np.random.RandomState,
min_frames: int,
max_frames: int,
min_wh: int,
max_wh: int,
):
# Temporary fix. Qwen2-VL video processor fails on video of shape
# (b, 199, 178, 3) where b in (3, 5, 7)
num_frames = rng.randint(min_frames, max_frames)
num_frames = (num_frames // 2) * 2

w, h = rng.randint(min_wh, max_wh, size=(2, ))
return rng.randint(0, 255, size=(num_frames, w, h, 3), dtype=np.uint8)


def _rand_audio(
rng: np.random.RandomState,
min_len: int,
max_len: int,
sr: int,
):
audio_len = rng.randint(min_len, max_len)
return rng.rand(audio_len), sr


def _test_processing_cache_correctness(
model_id: str,
modalities: set[str],
hit_rate: float,
num_batches: int,
simplify_rate: float,
):
if model_id == "TIGER-Lab/Mantis-8B-siglip-llama3":
hf_overrides = {"architectures": ["MantisForConditionalGeneration"]}
else:
hf_overrides = {}

model_config = ModelConfig(
model_id,
task="auto",
tokenizer=model_id,
tokenizer_mode="auto",
trust_remote_code=True,
seed=0,
dtype="float16",
revision=None,
hf_overrides=hf_overrides,
)
model_cls = MULTIMODAL_REGISTRY._get_model_cls(model_config)

processor_factory = MULTIMODAL_REGISTRY._processor_factories[model_cls]
ctx = InputProcessingContext(
model_config,
tokenizer=cached_get_tokenizer(model_config.tokenizer),
)
# Ensure that it can fit all of the data
cache = ProcessingCache(capacity=1 << 30)

baseline_processor = processor_factory(ctx, cache=None)
cached_processor = processor_factory(ctx, cache=cache)

rng = np.random.RandomState(0)

input_to_hit = {
"image": Image.new("RGB", size=(128, 128)),
"video": np.zeros((4, 128, 128, 3), dtype=np.uint8),
"audio": (np.zeros((512, )), 16000),
}
input_factory = {
"image":
partial(_rand_img, rng, min_wh=128, max_wh=256),
"video":
partial(_rand_video,
rng,
min_frames=2,
max_frames=8,
min_wh=128,
max_wh=256),
"audio":
partial(_rand_audio, rng, min_len=256, max_len=512, sr=16000),
}
input_max_count = {
"image": 3,
"video": 3,
"audio": 3,
}

for batch_idx in range(num_batches):
mm_data = {
k:
[(input_to_hit[k] if rng.rand() < hit_rate else input_factory[k]())
for _ in range(rng.randint(input_max_count[k]))]
for k in modalities
}

mm_counts = {k: len(vs) for k, vs in mm_data.items()}
prompt = baseline_processor._get_dummy_mm_inputs(mm_counts).prompt_text

# Drop unnecessary keys and test single -> multi conversion
if rng.rand() < simplify_rate:
for k in list(mm_data.keys()):
if not mm_data[k]:
del mm_data[k]
elif len(mm_data[k]) == 1:
mm_data[k] = mm_data[k][0]

baseline_result = baseline_processor.apply(
prompt,
mm_data=mm_data,
hf_processor_mm_kwargs={},
)
cached_result = cached_processor.apply(
prompt,
mm_data=mm_data,
hf_processor_mm_kwargs={},
)

assert baseline_result == cached_result, (
f"Failed ({batch_idx=}, {mm_data=})")


# yapf: disable
@pytest.mark.parametrize(("model_id", "modalities"), [
("llava-hf/llava-1.5-7b-hf", {"image"}),
("TIGER-Lab/Mantis-8B-siglip-llama3", {"image"}),
("mistral-community/pixtral-12b", {"image"}),
("Qwen/Qwen2-VL-2B-Instruct", {"image", "video"}),
("Qwen/Qwen2-Audio-7B-Instruct", {"audio"}),
("fixie-ai/ultravox-v0_3", {"audio"}),
])
@pytest.mark.parametrize("hit_rate", [0.3, 0.5, 1.0])
@pytest.mark.parametrize("num_batches", [32])
@pytest.mark.parametrize("simplify_rate", [1.0])
# yapf: enable
def test_processing_cache_correctness(
model_id: str,
modalities: set[str],
hit_rate: float,
num_batches: int,
simplify_rate: float,
):
_test_processing_cache_correctness(
model_id,
modalities,
hit_rate=hit_rate,
num_batches=num_batches,
simplify_rate=simplify_rate,
)


# yapf: disable
@pytest.mark.parametrize(("model_id", "modalities"), [
("microsoft/Phi-3-vision-128k-instruct", {"image"}),
])
@pytest.mark.parametrize("hit_rate", [0.3, 0.5, 1.0])
@pytest.mark.parametrize("num_batches", [32])
@pytest.mark.parametrize("simplify_rate", [1.0])
# yapf: enable
def test_processing_cache_correctness_phi3v(
model_id: str,
modalities: set[str],
hit_rate: float,
num_batches: int,
simplify_rate: float,
):
# HACK - this is an attempted workaround for the following bug
# https://github.com/huggingface/transformers/issues/34307
from transformers import AutoImageProcessor # noqa: F401
from transformers import AutoProcessor # noqa: F401

AutoImageProcessor.from_pretrained(model_id, trust_remote_code=True)

_test_processing_cache_correctness(
model_id,
modalities,
hit_rate=hit_rate,
num_batches=num_batches,
simplify_rate=simplify_rate,
)
21 changes: 10 additions & 11 deletions vllm/inputs/registry.py
@@ -99,6 +99,9 @@ def get_hf_processor(

merged_kwargs = {**base_kwargs, **kwargs}

if isinstance(typ, type):
merged_kwargs["processor_cls"] = typ

hf_processor = cached_get_processor(
self.model_config.model,
trust_remote_code=self.model_config.trust_remote_code,
@@ -132,33 +135,29 @@ def get_hf_processor(
def call_hf_processor(
self,
hf_processor: ProcessorMixin,
prompt: str,
processor_data: Mapping[str, object],
inference_kwargs: Mapping[str, object],
data: Mapping[str, object],
kwargs: Optional[Mapping[str, object]] = None,
) -> BatchFeature:
assert callable(hf_processor)

if kwargs is None:
kwargs = {}

base_kwargs = self.model_config.mm_processor_kwargs
if base_kwargs is None:
base_kwargs = {}

merged_kwargs = resolve_mm_processor_kwargs(
base_kwargs,
inference_kwargs,
kwargs,
hf_processor,
requires_kw_only=False,
allow_var_kwargs=True,
)

try:
return hf_processor(
text=prompt,
**processor_data,
**merged_kwargs,
return_tensors="pt",
)
return hf_processor(**data, **merged_kwargs, return_tensors="pt")
except Exception as exc:
data = dict(text=prompt, **processor_data)
msg = (f"Failed to apply {type(hf_processor).__name__} "
f"on data={data} with kwargs={merged_kwargs}")

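
For readers skimming the registry.py hunk above: call_hf_processor now takes a single data mapping (the prompt text together with the multi-modal inputs) plus an optional kwargs mapping, instead of separate prompt, processor_data, and inference_kwargs arguments. A hypothetical call site might look like the following; ctx, hf_processor, prompt, and images are placeholders, not names taken from the PR.

```python
# Hypothetical call-site sketch based on the new signature shown above;
# `ctx` is an InputProcessingContext and the other names are placeholders.
processed = ctx.call_hf_processor(
    hf_processor,
    data=dict(text=prompt, images=images),
    kwargs=hf_processor_mm_kwargs,
)
pixel_values = processed["pixel_values"]  # e.g., if the processor emits them
```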