Skip to content

Commit

Permalink
[Bugfix] Fix M-RoPE position calculation when chunked prefill is enab…
Browse files Browse the repository at this point in the history
…led (#10388)

Signed-off-by: imkero <[email protected]>
  • Loading branch information
imkero authored Nov 16, 2024
1 parent b98d89e commit 361c29e
Show file tree
Hide file tree
Showing 3 changed files with 135 additions and 5 deletions.
136 changes: 132 additions & 4 deletions tests/models/decoder_only/vision_language/test_qwen2_vl.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

IMAGE_PLACEHOLDER = "<|vision_start|><|image_pad|><|vision_end|>"
VIDEO_PLACEHOLDER = "<|vision_start|><|video_pad|><|vision_end|>"
MODEL_HIDDEN_SIZE = 1536


def qwen2_vl_chat_template(*query):
Expand Down Expand Up @@ -230,7 +231,7 @@ def batch_make_video_embeddings(
return result


def run_test(
def run_embedding_input_test(
vllm_runner: Type[VllmRunner],
inputs: List[Tuple[List[str], PromptImageInput, PromptVideoInput]],
model: str,
Expand Down Expand Up @@ -326,7 +327,7 @@ def test_qwen2_vl_image_embeddings_input(vllm_runner, image_assets, model,
[],
) for image, prompt in zip(images, IMAGE_PROMPTS)]

run_test(
run_embedding_input_test(
vllm_runner,
inputs_per_case,
model,
Expand Down Expand Up @@ -371,7 +372,7 @@ def test_qwen2_vl_multiple_image_embeddings_input(vllm_runner, image_assets,
[],
)]

run_test(
run_embedding_input_test(
vllm_runner,
inputs_per_case,
model,
Expand Down Expand Up @@ -416,7 +417,134 @@ def test_qwen2_vl_video_embeddings_input(vllm_runner, video_assets, model,
[rescale_video_size(video, factor) for factor in size_factors],
) for video, prompt in zip(sampled_vids, VIDEO_PROMPTS)]

run_test(
run_embedding_input_test(
vllm_runner,
inputs_per_case,
model,
dtype=dtype,
max_tokens=max_tokens,
num_logprobs=num_logprobs,
mm_limit=1,
tensor_parallel_size=1,
)


def run_chunked_prefill_test(
vllm_runner: Type[VllmRunner],
inputs: List[Tuple[List[str], PromptImageInput, PromptVideoInput]],
model: str,
*,
dtype: str,
max_tokens: int,
num_logprobs: int,
mm_limit: int,
tensor_parallel_size: int,
distributed_executor_backend: Optional[str] = None,
):
"""Compare inference result between
chunked prefill disabled and chunked prefill enabled
"""

# NOTE:
# max_model_len should be greater than image_feature_size
with vllm_runner(model,
task="generate",
max_model_len=4000,
max_num_seqs=4,
dtype=dtype,
limit_mm_per_prompt={
"image": mm_limit,
"video": mm_limit
},
tensor_parallel_size=tensor_parallel_size,
distributed_executor_backend=distributed_executor_backend
) as vllm_model:

outputs_per_case = [
vllm_model.generate_greedy_logprobs(prompts,
max_tokens,
num_logprobs=num_logprobs,
images=images or None,
videos=videos or None)
for prompts, images, videos in inputs
]

with vllm_runner(
model,
task="generate",
max_model_len=4000,
max_num_seqs=4,
dtype=dtype,
limit_mm_per_prompt={
"image": mm_limit,
"video": mm_limit
},
tensor_parallel_size=tensor_parallel_size,
distributed_executor_backend=distributed_executor_backend,
enable_chunked_prefill=True,
# should be small enough to ensure prefilling is chunked
max_num_batched_tokens=32,
mm_processor_kwargs={
"max_pixels": 16 * 28 * 28,
}) as vllm_model_chunked:
outputs_per_case_chunked = [
vllm_model_chunked.generate_greedy_logprobs(
prompts,
max_tokens,
num_logprobs=num_logprobs,
images=images or None,
videos=videos or None) for prompts, images, videos in inputs
]

for outputs, \
outputs_chunked \
in zip(outputs_per_case,
outputs_per_case_chunked):
check_logprobs_close(
outputs_0_lst=outputs,
outputs_1_lst=outputs_chunked,
name_0="non_chunked",
name_1="chunked",
)


@pytest.mark.core_model
@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize("dtype", [target_dtype])
@pytest.mark.parametrize("max_tokens", [1])
@pytest.mark.parametrize("num_logprobs", [10])
def test_qwen2_vl_mrope_chunked_prefill(vllm_runner, example_prompts,
model: str, dtype: str,
max_tokens: int,
num_logprobs: int) -> None:
"""
Test Qwen2-VL's chunked prefill with M-RoPE
"""
prompts = [
qwen2_vl_chat_template(IMAGE_PLACEHOLDER, prompt)
for prompt in example_prompts[:1]
]

# 1. Qwen2-VL's M-RoPE works only when there are some multi-modal inputs,
# so an image is included in the inputs
# 2. however, Qwen2-VL currently won't work properly
# when chunked prefill is enabled and there are some multi-modal inputs,
# here use a hacky way: provide a **zero-length** image to make it happy
#
# and finally we achieved:
# (1) chunked_prefill enabled; (2) M-RoPE works; to continue our tests
zero_len_image = {
"image_embeds": torch.empty((0, MODEL_HIDDEN_SIZE)),
"image_grid_thw": torch.tensor([[0, 0, 0]])
}
images = [zero_len_image] * len(prompts)

inputs_per_case: List[Tuple[List[str], PromptImageInput,
PromptVideoInput]] = [
(prompts, images, []),
]

run_chunked_prefill_test(
vllm_runner,
inputs_per_case,
model,
Expand Down
3 changes: 2 additions & 1 deletion vllm/model_executor/layers/rotary_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -847,6 +847,7 @@ def get_input_positions(
vision_end_token_id: int,
spatial_merge_size: int,
context_len: int = 0,
seq_len: Optional[int] = None,
) -> Tuple[List[List[int]], int]:
"""Get mrope input positions and delta value."""

Expand Down Expand Up @@ -921,7 +922,7 @@ def get_input_positions(
torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)

llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
llm_positions = llm_positions[:, context_len:]
llm_positions = llm_positions[:, context_len:seq_len]
mrope_position_delta = (llm_positions.max() + 1 -
len(input_tokens)).item()

Expand Down
1 change: 1 addition & 0 deletions vllm/worker/model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -700,6 +700,7 @@ def _compute_multi_modal_input(self, inter_data: InterDataForSeqGroup,
spatial_merge_size=hf_config.vision_config.
spatial_merge_size,
context_len=inter_data.context_lens[seq_idx],
seq_len=inter_data.seq_lens[seq_idx],
)

seq_data.mrope_position_delta = mrope_position_delta
Expand Down

0 comments on commit 361c29e

Please sign in to comment.