[WIP][VLM] Implement merged multimodal processor for Mllama #11427

Draft · wants to merge 15 commits into base: main
6 changes: 4 additions & 2 deletions vllm/inputs/data.py
@@ -9,7 +9,8 @@
if TYPE_CHECKING:
from vllm.multimodal import (MultiModalDataDict, MultiModalKwargs,
MultiModalPlaceholderDict)
from vllm.multimodal.inputs import MultiModalInputsV2
from vllm.multimodal.inputs import (MultiModalEncDecInputs,
MultiModalInputsV2)


class TextPrompt(TypedDict):
@@ -229,7 +230,8 @@ class EncoderDecoderInputs(TypedDict):
"""The inputs for the decoder portion."""


SingletonInputs = Union[TokenInputs, "MultiModalInputsV2"]
SingletonInputs = Union[TokenInputs, "MultiModalInputsV2",
"MultiModalEncDecInputs"]
"""
A processed :class:`SingletonPrompt` which can be passed to
:class:`vllm.sequence.Sequence`.
27 changes: 23 additions & 4 deletions vllm/inputs/preprocess.py
@@ -1,13 +1,15 @@
import asyncio
from typing import List, Mapping, Optional, Union
from typing import List, Mapping, Optional, Union, cast

from typing_extensions import assert_never

from vllm.config import ModelConfig
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.multimodal.processing import MultiModalDataDict, MultiModalInputsV2
from vllm.multimodal.processing import (MultiModalDataDict,
MultiModalEncDecInputs,
MultiModalInputsV2)
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup
from vllm.utils import print_info_once, print_warning_once
@@ -512,12 +514,29 @@ def _process_encoder_decoder_prompt(
request_id=request_id,
)
else:
encoder_inputs = self._prompt_to_llm_inputs(
inputs = self._prompt_to_llm_inputs(
prompt,
request_id=request_id,
)
if inputs["type"] == "multimodal":
assert ("encoder_prompt" in inputs
and "encoder_prompt_token_ids" in inputs)
inputs = cast(MultiModalEncDecInputs, inputs)
encoder_inputs = token_inputs(
prompt=inputs["encoder_prompt"],
prompt_token_ids=inputs["encoder_prompt_token_ids"],
)
decoder_inputs = MultiModalInputsV2(
type="multimodal",
prompt=inputs["prompt"],
prompt_token_ids=inputs["prompt_token_ids"],
mm_kwargs=inputs["mm_kwargs"],
mm_placeholders=inputs["mm_placeholders"],
)
else:
encoder_inputs = inputs

decoder_inputs = None
decoder_inputs = None

return self._build_enc_dec_llm_inputs(encoder_inputs, decoder_inputs)

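Note: the following is a minimal sketch of the split performed by the new branch above, with plain dicts standing in for vLLM's TokenInputs / MultiModalInputsV2 TypedDicts; all field values are invented for illustration.

# Hedged sketch only; not the actual vLLM call path.
from typing import Any, Dict, Optional, Tuple

def split_enc_dec(inputs: Dict[str, Any]) -> Tuple[Dict[str, Any], Optional[Dict[str, Any]]]:
    if inputs["type"] == "multimodal":
        # Encoder side keeps only the processed encoder prompt.
        encoder = {
            "type": "token",
            "prompt": inputs["encoder_prompt"],
            "prompt_token_ids": inputs["encoder_prompt_token_ids"],
        }
        # Decoder side keeps the text prompt plus multimodal tensors/placeholders.
        decoder = {
            "type": "multimodal",
            "prompt": inputs["prompt"],
            "prompt_token_ids": inputs["prompt_token_ids"],
            "mm_kwargs": inputs["mm_kwargs"],
            "mm_placeholders": inputs["mm_placeholders"],
        }
        return encoder, decoder
    # Token-only inputs pass through unchanged; the decoder side is filled later.
    return inputs, None

encoder, decoder = split_enc_dec({
    "type": "multimodal",
    "prompt": "<|image|><|begin_of_text|>What is the content of this image?",
    "prompt_token_ids": [128000, 128256, 128000, 3923],  # invented IDs
    "encoder_prompt": "<|image|>",
    "encoder_prompt_token_ids": [128256],
    "mm_kwargs": {},
    "mm_placeholders": {"image": []},
})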
14 changes: 11 additions & 3 deletions vllm/inputs/registry.py
@@ -323,6 +323,7 @@ def dummy_data_for_profiling(
# Avoid circular import
from vllm.model_executor.model_loader import get_model_architecture
from vllm.multimodal import MultiModalKwargs
from vllm.multimodal.processing import EncDecMultiModalProcessor
from vllm.multimodal.utils import cached_get_tokenizer

if mm_registry.has_processor(model_config):
@@ -335,9 +336,16 @@
mm_counts = mm_registry.get_mm_limits_per_prompt(model_config)
mm_max_tokens = mm_registry.get_max_tokens_by_modality(
model_config)

dummy_data = processor.get_dummy_data(seq_len, mm_counts,
mm_max_tokens)
if is_encoder_data:
assert isinstance(processor, EncDecMultiModalProcessor)
dummy_data = processor.get_dummy_data(
seq_len,
mm_counts,
mm_max_tokens,
is_encoder_data=is_encoder_data)
else:
dummy_data = processor.get_dummy_data(seq_len, mm_counts,
mm_max_tokens)
else:
model_cls, _ = get_model_architecture(model_config)
if is_encoder_data:
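A self-contained sketch of the dispatch above, with stub classes standing in for vLLM's processor interfaces (class names and return shapes here are assumptions, not the real API; the real get_dummy_data returns a DummyData object):

# Stub classes for illustration only.
from typing import Mapping

class MergedProcessorStub:
    def get_dummy_data(self, seq_len: int, mm_counts: Mapping[str, int],
                       mm_max_tokens: Mapping[str, int]) -> dict:
        return {"prompt_token_ids": [0] * seq_len}

class EncDecProcessorStub(MergedProcessorStub):
    def get_dummy_data(self, seq_len, mm_counts, mm_max_tokens,
                       is_encoder_data: bool = False) -> dict:
        if is_encoder_data:
            # Encoder dummy data is sized by the per-prompt image token budget,
            # not by the decoder sequence length.
            n = mm_counts["image"] * mm_max_tokens["image"]
            return {"prompt_token_ids": [1] * n}
        return super().get_dummy_data(seq_len, mm_counts, mm_max_tokens)

def profile_dummy_data(processor, seq_len, mm_counts, mm_max_tokens, is_encoder_data):
    if is_encoder_data:
        # Only an encoder-decoder processor accepts the extra flag.
        assert isinstance(processor, EncDecProcessorStub)
        return processor.get_dummy_data(seq_len, mm_counts, mm_max_tokens,
                                        is_encoder_data=True)
    return processor.get_dummy_data(seq_len, mm_counts, mm_max_tokens)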
225 changes: 78 additions & 147 deletions vllm/model_executor/models/mllama.py
@@ -23,21 +23,19 @@
import transformers.models.mllama.configuration_mllama as config_mllama
from PIL import Image
from torch import nn
from transformers import BatchFeature, MllamaConfig
from transformers.modeling_outputs import (BaseModelOutput,
CausalLMOutputWithPast)
from transformers.models.mllama.image_processing_mllama import (
get_optimal_tiled_canvas)
from transformers.models.mllama.processing_mllama import (
get_cross_attention_token_mask)
MllamaProcessor, get_cross_attention_token_mask)

import vllm.distributed.parallel_state as ps
from vllm.attention import Attention, AttentionMetadata, AttentionType
from vllm.attention.ops.paged_attn import PagedAttention
from vllm.attention.selector import _Backend
from vllm.config import VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.inputs import (INPUT_REGISTRY, DummyData, EncoderDecoderInputs,
InputContext, TokenInputs, token_inputs)
from vllm.inputs import InputContext
from vllm.logger import init_logger
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
@@ -51,8 +49,9 @@
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.sequence import SequenceData
from vllm.utils import is_list_of
from vllm.multimodal.processing import (EncDecMultiModalProcessor,
MultiModalDataItems, ProcessorInputs,
PromptReplacement)

from .clip import CLIPMLP
from .interfaces import SupportsMultiModal
@@ -78,151 +77,86 @@ class MllamaImagePixelInputs(TypedDict):
# TODO: support LlamaImageEmbeddingInputs


def _get_num_image_in_last_group(prompt_token_ids: List[int]) -> int:
num_images = 0
for token_id in prompt_token_ids[::-1]:
if token_id == MLLAMA_IMAGE_TOKEN_ID:
num_images += 1
elif num_images > 0:
break
return num_images


def input_processor_for_mllama(
ctx: InputContext,
inputs: EncoderDecoderInputs,
) -> EncoderDecoderInputs:
# Example input to processor:
# {
# 'encoder': {
# 'type': 'token',
# 'prompt_token_ids': [128000, 128256, 128000, 3923, 374, 279, 2262, 315, 420, 2217, 30], # noqa: E501
# 'prompt': '<|image|><|begin_of_text|>What is the content of this image?', # noqa: E501
# 'multi_modal_data': {'image': <PIL.Image.Image image mode=RGB size=1770x1180 at 0x7FDE2C624880>}, # noqa: E501
# },
# 'decoder': {
# 'type': 'token',
# 'prompt_token_ids': [128000],
# },
# }

# move encoder prompt to decoder
dec_inputs = TokenInputs(**inputs["encoder"])

multi_modal_data = dec_inputs.get("multi_modal_data")
if multi_modal_data is None or "image" not in multi_modal_data:
# text-only
return EncoderDecoderInputs(
encoder=token_inputs([]),
decoder=dec_inputs,
)

image_data = multi_modal_data["image"]
if isinstance(image_data, Image.Image):
image_data = [image_data]

assert is_list_of(image_data, Image.Image)

# Since only the last group of consecutive images
# are attended by the decoded tokens, we only need to
# get the number of tiles for those images.
num_decode_images = _get_num_image_in_last_group(
dec_inputs["prompt_token_ids"])

hf_config = ctx.model_config.hf_config
vision_config = hf_config.vision_config

num_tiles = 0
for image in image_data[::-1]:
width, height = image.size
tile_size = vision_config.image_size
canvas_height, canvas_width = get_optimal_tiled_canvas(
image_height=height,
image_width=width,
max_image_tiles=vision_config.max_num_tiles,
tile_size=tile_size,
)
num_tiles_height = canvas_height // tile_size
num_tiles_width = canvas_width // tile_size
num_tiles += num_tiles_height * num_tiles_width
num_decode_images -= 1
if num_decode_images == 0:
break

# Set encoder prompt length based on the number of tiles.
# This tells the block manager to allocate correct number
# of slots for encoder tokens.
assert vision_config.image_size % 14 == 0, \
"chunk size should be multiple of 14"
token_per_chunk = (vision_config.image_size // 14)**2 + 1
num_tokens = num_tiles * token_per_chunk

# Example output from processor:
# {
# 'encoder': {
# 'type': 'token',
# 'prompt_token_ids': [128256, 128256, ..., 128256],
# 'prompt': '<|image|><|image|>...<|image|>',
# 'multi_modal_data': {'image': <PIL.Image.Image image mode=RGB size=1770x1180 at 0x7FDE2C624880>}, # noqa: E501
# },
# 'decoder': {
# 'type': 'token',
# 'prompt_token_ids': [128000, 128256, 128000, 3923, 374, 279, 2262, 315, 420, 2217, 30], # noqa: E501
# 'prompt': '<|image|><|begin_of_text|>What is the content of this image?', # noqa: E501
# 'multi_modal_data': {'image': <PIL.Image.Image image mode=RGB size=1770x1180 at 0x7FDE2C624880>}, # noqa: E501
# },
# }
return EncoderDecoderInputs(
encoder=token_inputs(
prompt_token_ids=[MLLAMA_IMAGE_TOKEN_ID] * num_tokens,
prompt=MLLAMA_IMAGE_TOKEN * num_tokens,
multi_modal_data=multi_modal_data,
),
decoder=dec_inputs,
)


def get_max_mllama_image_tokens(ctx: InputContext) -> int:
hf_config = ctx.model_config.hf_config
token_per_chunk = (hf_config.vision_config.image_size // 14)**2 + 1
return hf_config.vision_config.max_num_tiles * token_per_chunk


def dummy_decoder_seq_data(seq_len: int, num_images: int):
# <|image|> * num_images + 0 * (seq_len - num_images)
assert seq_len >= num_images, \
"seq_len should be greater than or equal to num_images"

return SequenceData.from_prompt_token_counts(
(MLLAMA_IMAGE_TOKEN_ID, num_images),
(0, seq_len - num_images),
)

class MllamaMultiModalProcessor(EncDecMultiModalProcessor):

def dummy_encoder_seq_data(ctx: InputContext, num_images: int):
num_tokens = get_max_mllama_image_tokens(ctx) * num_images

return SequenceData.from_prompt_token_counts(
(MLLAMA_IMAGE_TOKEN_ID, num_tokens))


def dummy_image(num_images: int, ):
width = height = 1024
image = Image.new("RGB", (width, height), color=0)
return {"image": image if num_images == 1 else [image] * num_images}


def dummy_decoder_data_for_mllama(ctx: InputContext, seq_len: int,
mm_counts: Mapping[str, int]):
num_images = mm_counts["image"]
return DummyData(dummy_decoder_seq_data(seq_len, num_images))
def _get_hf_processor(self) -> MllamaProcessor:
return self.ctx.get_hf_processor(MllamaProcessor)

def _call_hf_processor(
self,
hf_processor: MllamaProcessor,
prompt: str,
processor_data: Mapping[str, object],
mm_processor_kwargs: Mapping[str, object],
) -> BatchFeature:
# Calling MllamaProcessor directly drops `num_tiles` from the image
# processor outputs, but `num_tiles` is essential for the forward pass
# in the vLLM implementation. Therefore, we call the image processor
# separately to keep `num_tiles`.
image_processor = hf_processor.image_processor
image_features = image_processor(**processor_data)

tokenizer = self._get_tokenizer()
encoding = tokenizer(prompt,
add_special_tokens=False,
return_tensors="pt")
data = dict(**encoding, **image_features)

return BatchFeature(data=data, tensor_type="pt")

def _create_encoder_prompt(self, prompt: str):
hf_processor = self._get_hf_processor()
image_token = hf_processor.image_token
num_images = prompt.count(image_token)
return image_token * num_images

def _get_prompt_replacements(
self,
mm_items: MultiModalDataItems,
hf_inputs: BatchFeature,
mm_processor_kwargs: Mapping[str, object],
) -> list[PromptReplacement]:
vision_config = self.ctx.get_hf_config(MllamaConfig).vision_config
assert vision_config.image_size % 14 == 0, (
"chunk size should be multiple of 14")
token_per_chunk = (vision_config.image_size // 14)**2 + 1
hf_processor = self._get_hf_processor()
image_token_id = hf_processor.image_token_id

def get_replacement_mllama(item_idx):
num_tile = hf_inputs["num_tiles"][0][item_idx]
num_tokens = num_tile * token_per_chunk
return [image_token_id] * num_tokens

return [
PromptReplacement(
modality="image",
target=[image_token_id],
replacement=get_replacement_mllama,
)
]

def dummy_encoder_data_for_mllama(ctx: InputContext, seq_len: int,
mm_counts: Mapping[str, int]):
num_images = mm_counts["image"]
return DummyData(dummy_encoder_seq_data(ctx, num_images),
dummy_image(num_images))
def _get_dummy_mm_inputs(
self,
mm_counts: Mapping[str, int],
) -> ProcessorInputs:
num_images = mm_counts["image"]
hf_processor = self._get_hf_processor()
image_token: str = hf_processor.image_token

width = height = 1024
image = Image.new("RGB", (width, height), color=0)

return ProcessorInputs(
prompt_text=image_token * num_images,
mm_data={"image": image},
mm_processor_kwargs={},
)


def _prepare_aspect_ratio_attention_mask(
@@ -1097,11 +1031,8 @@ def forward(
return hidden_states


@MULTIMODAL_REGISTRY.register_image_input_mapper()
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_mllama_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_decoder_data_for_mllama)
@INPUT_REGISTRY.register_dummy_encoder_data(dummy_encoder_data_for_mllama)
@INPUT_REGISTRY.register_input_processor(input_processor_for_mllama)
@MULTIMODAL_REGISTRY.register_processor(MllamaMultiModalProcessor)
class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal):
# BitandBytes specific attributes
bitsandbytes_stacked_params_mapping = {
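As a worked check of the replacement arithmetic in _get_prompt_replacements (assuming a typical Llama-3.2-Vision config with image_size=560 and max_num_tiles=4; the actual values come from the model's vision_config): each tile contributes (560 // 14) ** 2 + 1 = 1601 tokens, so an image the HF image processor splits into 4 tiles expands one <|image|> placeholder into 6404 encoder tokens.

# Assumed config values for illustration; read them from vision_config in practice.
image_size = 560   # assumed vision_config.image_size
patch_size = 14    # matches the hard-coded 14 in the processor above
num_tiles = 4      # e.g. num_tiles reported by the HF image processor for a large image

assert image_size % patch_size == 0, "chunk size should be a multiple of 14"
token_per_chunk = (image_size // patch_size) ** 2 + 1   # 1601
num_image_tokens = num_tiles * token_per_chunk          # 6404
print(token_per_chunk, num_image_tokens)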
16 changes: 16 additions & 0 deletions vllm/multimodal/inputs.py
@@ -224,3 +224,19 @@ class MultiModalInputsV2(TypedDict):
For each modality, information about the placeholder tokens in
:code:`prompt_token_ids`.
"""


class MultiModalEncDecInputs(MultiModalInputsV2):
"""
Represents the outputs of :class:`vllm.multimodal.EncDecMultiModalProcessor`
ready to be passed to vLLM internals.
"""

encoder_prompt: str
"""The processed encoder prompt text."""

encoder_prompt_token_ids: List[int]
"""The processed token IDs of the encoder prompt."""

encoder_token_type_ids: NotRequired[List[int]]
"""The token type IDs of the encoder prompt."""