
Commit 361c29e

[Bugfix] Fix M-RoPE position calculation when chunked prefill is enabled (#10388)
Signed-off-by: imkero <kerorek@outlook.com>
1 parent: b98d89e

File tree (3 files changed: +135 −5 lines changed):
  tests/models/decoder_only/vision_language/test_qwen2_vl.py (+132 −4)
  vllm/model_executor/layers/rotary_embedding.py (+2 −1)
  vllm/worker/model_runner.py (+1)

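The bug: MRotaryEmbedding.get_input_positions builds one 3-D M-RoPE position per prompt token, but with chunked prefill each scheduler step only prefills the token window [context_len, seq_len). Slicing the positions with only the lower bound returns positions for every remaining token, not just the current chunk. A minimal numeric sketch of the mismatch (illustrative only, not vLLM's API; the names below are placeholders):

import torch

# One 3-D M-RoPE position per token of an 8-token prompt.
llm_positions = torch.arange(8).view(1, -1).expand(3, -1)  # shape (3, 8)

# This chunked-prefill step covers tokens 4 and 5 only.
context_len, seq_len = 4, 6

old_slice = llm_positions[:, context_len:]          # (3, 4): 4 positions for a 2-token chunk
new_slice = llm_positions[:, context_len:seq_len]   # (3, 2): matches the chunk exactly

assert new_slice.shape[-1] == seq_len - context_len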
tests/models/decoder_only/vision_language/test_qwen2_vl.py

+132 −4
@@ -18,6 +18,7 @@

 IMAGE_PLACEHOLDER = "<|vision_start|><|image_pad|><|vision_end|>"
 VIDEO_PLACEHOLDER = "<|vision_start|><|video_pad|><|vision_end|>"
+MODEL_HIDDEN_SIZE = 1536


 def qwen2_vl_chat_template(*query):
@@ -230,7 +231,7 @@ def batch_make_video_embeddings(
     return result


-def run_test(
+def run_embedding_input_test(
     vllm_runner: Type[VllmRunner],
     inputs: List[Tuple[List[str], PromptImageInput, PromptVideoInput]],
     model: str,
@@ -326,7 +327,7 @@ def test_qwen2_vl_image_embeddings_input(vllm_runner, image_assets, model,
         [],
     ) for image, prompt in zip(images, IMAGE_PROMPTS)]

-    run_test(
+    run_embedding_input_test(
         vllm_runner,
         inputs_per_case,
         model,
@@ -371,7 +372,7 @@ def test_qwen2_vl_multiple_image_embeddings_input(vllm_runner, image_assets,
         [],
     )]

-    run_test(
+    run_embedding_input_test(
         vllm_runner,
         inputs_per_case,
         model,
@@ -416,7 +417,134 @@ def test_qwen2_vl_video_embeddings_input(vllm_runner, video_assets, model,
         [rescale_video_size(video, factor) for factor in size_factors],
     ) for video, prompt in zip(sampled_vids, VIDEO_PROMPTS)]

-    run_test(
+    run_embedding_input_test(
+        vllm_runner,
+        inputs_per_case,
+        model,
+        dtype=dtype,
+        max_tokens=max_tokens,
+        num_logprobs=num_logprobs,
+        mm_limit=1,
+        tensor_parallel_size=1,
+    )
+
+
+def run_chunked_prefill_test(
+    vllm_runner: Type[VllmRunner],
+    inputs: List[Tuple[List[str], PromptImageInput, PromptVideoInput]],
+    model: str,
+    *,
+    dtype: str,
+    max_tokens: int,
+    num_logprobs: int,
+    mm_limit: int,
+    tensor_parallel_size: int,
+    distributed_executor_backend: Optional[str] = None,
+):
+    """Compare inference result between
+    chunked prefill disabled and chunked prefill enabled
+    """
+
+    # NOTE:
+    # max_model_len should be greater than image_feature_size
+    with vllm_runner(model,
+                     task="generate",
+                     max_model_len=4000,
+                     max_num_seqs=4,
+                     dtype=dtype,
+                     limit_mm_per_prompt={
+                         "image": mm_limit,
+                         "video": mm_limit
+                     },
+                     tensor_parallel_size=tensor_parallel_size,
+                     distributed_executor_backend=distributed_executor_backend
+                     ) as vllm_model:
+
+        outputs_per_case = [
+            vllm_model.generate_greedy_logprobs(prompts,
+                                                max_tokens,
+                                                num_logprobs=num_logprobs,
+                                                images=images or None,
+                                                videos=videos or None)
+            for prompts, images, videos in inputs
+        ]
+
+    with vllm_runner(
+            model,
+            task="generate",
+            max_model_len=4000,
+            max_num_seqs=4,
+            dtype=dtype,
+            limit_mm_per_prompt={
+                "image": mm_limit,
+                "video": mm_limit
+            },
+            tensor_parallel_size=tensor_parallel_size,
+            distributed_executor_backend=distributed_executor_backend,
+            enable_chunked_prefill=True,
+            # should be small enough to ensure prefilling is chunked
+            max_num_batched_tokens=32,
+            mm_processor_kwargs={
+                "max_pixels": 16 * 28 * 28,
+            }) as vllm_model_chunked:
+        outputs_per_case_chunked = [
+            vllm_model_chunked.generate_greedy_logprobs(
+                prompts,
+                max_tokens,
+                num_logprobs=num_logprobs,
+                images=images or None,
+                videos=videos or None) for prompts, images, videos in inputs
+        ]
+
+    for outputs, \
+        outputs_chunked \
+        in zip(outputs_per_case,
+               outputs_per_case_chunked):
+        check_logprobs_close(
+            outputs_0_lst=outputs,
+            outputs_1_lst=outputs_chunked,
+            name_0="non_chunked",
+            name_1="chunked",
+        )
+
+
+@pytest.mark.core_model
+@pytest.mark.parametrize("model", models)
+@pytest.mark.parametrize("dtype", [target_dtype])
+@pytest.mark.parametrize("max_tokens", [1])
+@pytest.mark.parametrize("num_logprobs", [10])
+def test_qwen2_vl_mrope_chunked_prefill(vllm_runner, example_prompts,
+                                        model: str, dtype: str,
+                                        max_tokens: int,
+                                        num_logprobs: int) -> None:
+    """
+    Test Qwen2-VL's chunked prefill with M-RoPE
+    """
+    prompts = [
+        qwen2_vl_chat_template(IMAGE_PLACEHOLDER, prompt)
+        for prompt in example_prompts[:1]
+    ]
+
+    # 1. Qwen2-VL's M-RoPE works only when there are some multi-modal inputs,
+    #    so an image is included in the inputs
+    # 2. however, Qwen2-VL currently won't work properly
+    #    when chunked prefill is enabled and there are some multi-modal inputs,
+    #    here use a hacky way: provide a **zero-length** image to make it happy
+    #
+    # and finally we achieved:
+    # (1) chunked_prefill enabled; (2) M-RoPE works; to continue our tests
+    zero_len_image = {
+        "image_embeds": torch.empty((0, MODEL_HIDDEN_SIZE)),
+        "image_grid_thw": torch.tensor([[0, 0, 0]])
+    }
+    images = [zero_len_image] * len(prompts)
+
+    inputs_per_case: List[Tuple[List[str], PromptImageInput,
+                                PromptVideoInput]] = [
+        (prompts, images, []),
+    ]
+
+    run_chunked_prefill_test(
         vllm_runner,
         inputs_per_case,
         model,

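The new run_chunked_prefill_test compares greedy logprobs with chunked prefill disabled and enabled, forcing chunking via max_num_batched_tokens=32, and test_qwen2_vl_mrope_chunked_prefill feeds a zero-length image so the multi-modal (and therefore M-RoPE) path stays active without adding any image tokens. A small sketch of why that placeholder is token-free (assuming, as is conventional for Qwen2-VL inputs, that the patch count is the product of the image_grid_thw entries):

import torch

MODEL_HIDDEN_SIZE = 1536  # hidden size used by the test above

zero_len_image = {
    "image_embeds": torch.empty((0, MODEL_HIDDEN_SIZE)),  # no embedding rows
    "image_grid_thw": torch.tensor([[0, 0, 0]]),          # t * h * w == 0 patches
}

num_patches = int(zero_len_image["image_grid_thw"].prod(dim=-1).sum())
assert num_patches == 0 and zero_len_image["image_embeds"].shape[0] == 0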
vllm/model_executor/layers/rotary_embedding.py

+2 −1
@@ -847,6 +847,7 @@ def get_input_positions(
         vision_end_token_id: int,
         spatial_merge_size: int,
         context_len: int = 0,
+        seq_len: Optional[int] = None,
     ) -> Tuple[List[List[int]], int]:
         """Get mrope input positions and delta value."""

@@ -921,7 +922,7 @@ def get_input_positions(
                 torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)

         llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
-        llm_positions = llm_positions[:, context_len:]
+        llm_positions = llm_positions[:, context_len:seq_len]
         mrope_position_delta = (llm_positions.max() + 1 -
                                 len(input_tokens)).item()

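Since seq_len defaults to None, pre-existing callers that do not pass it keep the old behavior: a slice whose upper bound is None runs to the end of the tensor. A quick check of that equivalence (plain PyTorch, independent of vLLM):

import torch

t = torch.arange(12).reshape(3, 4)
context_len = 2
# [:, context_len:None] and [:, context_len:] select the same columns.
assert torch.equal(t[:, context_len:None], t[:, context_len:])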
vllm/worker/model_runner.py

+1
@@ -700,6 +700,7 @@ def _compute_multi_modal_input(self, inter_data: InterDataForSeqGroup,
                     spatial_merge_size=hf_config.vision_config.
                     spatial_merge_size,
                     context_len=inter_data.context_lens[seq_idx],
+                    seq_len=inter_data.seq_lens[seq_idx],
                 )

            seq_data.mrope_position_delta = mrope_position_delta

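Here inter_data.context_lens[seq_idx] is the number of tokens already computed for the sequence and inter_data.seq_lens[seq_idx] is the count after this step, so each chunked-prefill step now requests exactly its own slice of M-RoPE positions. A sketch of how those windows tile a prompt (assuming fixed-size chunks, e.g. the 32-token max_num_batched_tokens used in the test above):

prompt_len = 100
chunk_size = 32

windows = []
context_len = 0
while context_len < prompt_len:
    # This step prefills tokens [context_len, seq_len).
    seq_len = min(context_len + chunk_size, prompt_len)
    windows.append((context_len, seq_len))
    context_len = seq_len

print(windows)  # [(0, 32), (32, 64), (64, 96), (96, 100)]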