diff --git a/tests/models/decoder_only/vision_language/test_qwen2_vl.py b/tests/models/decoder_only/vision_language/test_qwen2_vl.py index 718c675b86fb4..71b6ba4dca435 100644 --- a/tests/models/decoder_only/vision_language/test_qwen2_vl.py +++ b/tests/models/decoder_only/vision_language/test_qwen2_vl.py @@ -18,6 +18,7 @@ IMAGE_PLACEHOLDER = "<|vision_start|><|image_pad|><|vision_end|>" VIDEO_PLACEHOLDER = "<|vision_start|><|video_pad|><|vision_end|>" +MODEL_HIDDEN_SIZE = 1536 def qwen2_vl_chat_template(*query): @@ -230,7 +231,7 @@ def batch_make_video_embeddings( return result -def run_test( +def run_embedding_input_test( vllm_runner: Type[VllmRunner], inputs: List[Tuple[List[str], PromptImageInput, PromptVideoInput]], model: str, @@ -326,7 +327,7 @@ def test_qwen2_vl_image_embeddings_input(vllm_runner, image_assets, model, [], ) for image, prompt in zip(images, IMAGE_PROMPTS)] - run_test( + run_embedding_input_test( vllm_runner, inputs_per_case, model, @@ -371,7 +372,7 @@ def test_qwen2_vl_multiple_image_embeddings_input(vllm_runner, image_assets, [], )] - run_test( + run_embedding_input_test( vllm_runner, inputs_per_case, model, @@ -416,7 +417,134 @@ def test_qwen2_vl_video_embeddings_input(vllm_runner, video_assets, model, [rescale_video_size(video, factor) for factor in size_factors], ) for video, prompt in zip(sampled_vids, VIDEO_PROMPTS)] - run_test( + run_embedding_input_test( + vllm_runner, + inputs_per_case, + model, + dtype=dtype, + max_tokens=max_tokens, + num_logprobs=num_logprobs, + mm_limit=1, + tensor_parallel_size=1, + ) + + +def run_chunked_prefill_test( + vllm_runner: Type[VllmRunner], + inputs: List[Tuple[List[str], PromptImageInput, PromptVideoInput]], + model: str, + *, + dtype: str, + max_tokens: int, + num_logprobs: int, + mm_limit: int, + tensor_parallel_size: int, + distributed_executor_backend: Optional[str] = None, +): + """Compare inference result between + chunked prefill disabled and chunked prefill enabled + """ + + # NOTE: + # max_model_len should be greater than image_feature_size + with vllm_runner(model, + task="generate", + max_model_len=4000, + max_num_seqs=4, + dtype=dtype, + limit_mm_per_prompt={ + "image": mm_limit, + "video": mm_limit + }, + tensor_parallel_size=tensor_parallel_size, + distributed_executor_backend=distributed_executor_backend + ) as vllm_model: + + outputs_per_case = [ + vllm_model.generate_greedy_logprobs(prompts, + max_tokens, + num_logprobs=num_logprobs, + images=images or None, + videos=videos or None) + for prompts, images, videos in inputs + ] + + with vllm_runner( + model, + task="generate", + max_model_len=4000, + max_num_seqs=4, + dtype=dtype, + limit_mm_per_prompt={ + "image": mm_limit, + "video": mm_limit + }, + tensor_parallel_size=tensor_parallel_size, + distributed_executor_backend=distributed_executor_backend, + enable_chunked_prefill=True, + # should be small enough to ensure prefilling is chunked + max_num_batched_tokens=32, + mm_processor_kwargs={ + "max_pixels": 16 * 28 * 28, + }) as vllm_model_chunked: + outputs_per_case_chunked = [ + vllm_model_chunked.generate_greedy_logprobs( + prompts, + max_tokens, + num_logprobs=num_logprobs, + images=images or None, + videos=videos or None) for prompts, images, videos in inputs + ] + + for outputs, \ + outputs_chunked \ + in zip(outputs_per_case, + outputs_per_case_chunked): + check_logprobs_close( + outputs_0_lst=outputs, + outputs_1_lst=outputs_chunked, + name_0="non_chunked", + name_1="chunked", + ) + + +@pytest.mark.core_model +@pytest.mark.parametrize("model", models) +@pytest.mark.parametrize("dtype", [target_dtype]) +@pytest.mark.parametrize("max_tokens", [1]) +@pytest.mark.parametrize("num_logprobs", [10]) +def test_qwen2_vl_mrope_chunked_prefill(vllm_runner, example_prompts, + model: str, dtype: str, + max_tokens: int, + num_logprobs: int) -> None: + """ + Test Qwen2-VL's chunked prefill with M-RoPE + """ + prompts = [ + qwen2_vl_chat_template(IMAGE_PLACEHOLDER, prompt) + for prompt in example_prompts[:1] + ] + + # 1. Qwen2-VL's M-RoPE works only when there are some multi-modal inputs, + # so an image is included in the inputs + # 2. however, Qwen2-VL currently won't work properly + # when chunked prefill is enabled and there are some multi-modal inputs, + # here use a hacky way: provide a **zero-length** image to make it happy + # + # and finally we achieved: + # (1) chunked_prefill enabled; (2) M-RoPE works; to continue our tests + zero_len_image = { + "image_embeds": torch.empty((0, MODEL_HIDDEN_SIZE)), + "image_grid_thw": torch.tensor([[0, 0, 0]]) + } + images = [zero_len_image] * len(prompts) + + inputs_per_case: List[Tuple[List[str], PromptImageInput, + PromptVideoInput]] = [ + (prompts, images, []), + ] + + run_chunked_prefill_test( vllm_runner, inputs_per_case, model, diff --git a/vllm/model_executor/layers/rotary_embedding.py b/vllm/model_executor/layers/rotary_embedding.py index 63ceec63e8317..b01e4c61fe101 100644 --- a/vllm/model_executor/layers/rotary_embedding.py +++ b/vllm/model_executor/layers/rotary_embedding.py @@ -847,6 +847,7 @@ def get_input_positions( vision_end_token_id: int, spatial_merge_size: int, context_len: int = 0, + seq_len: Optional[int] = None, ) -> Tuple[List[List[int]], int]: """Get mrope input positions and delta value.""" @@ -921,7 +922,7 @@ def get_input_positions( torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) - llm_positions = llm_positions[:, context_len:] + llm_positions = llm_positions[:, context_len:seq_len] mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item() diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 042f9f07eace6..22ee3f9f863e4 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -700,6 +700,7 @@ def _compute_multi_modal_input(self, inter_data: InterDataForSeqGroup, spatial_merge_size=hf_config.vision_config. spatial_merge_size, context_len=inter_data.context_lens[seq_idx], + seq_len=inter_data.seq_lens[seq_idx], ) seq_data.mrope_position_delta = mrope_position_delta