
 IMAGE_PLACEHOLDER = "<|vision_start|><|image_pad|><|vision_end|>"
 VIDEO_PLACEHOLDER = "<|vision_start|><|video_pad|><|vision_end|>"
+MODEL_HIDDEN_SIZE = 1536  # hidden size of the model under test; shapes the zero-length image embedding below


 def qwen2_vl_chat_template(*query):
@@ -230,7 +231,7 @@ def batch_make_video_embeddings(
     return result


-def run_test(
+def run_embedding_input_test(
     vllm_runner: Type[VllmRunner],
     inputs: List[Tuple[List[str], PromptImageInput, PromptVideoInput]],
     model: str,
@@ -326,7 +327,7 @@ def test_qwen2_vl_image_embeddings_input(vllm_runner, image_assets, model,
         [],
     ) for image, prompt in zip(images, IMAGE_PROMPTS)]

-    run_test(
+    run_embedding_input_test(
         vllm_runner,
         inputs_per_case,
         model,
@@ -371,7 +372,7 @@ def test_qwen2_vl_multiple_image_embeddings_input(vllm_runner, image_assets,
         [],
     )]

-    run_test(
+    run_embedding_input_test(
         vllm_runner,
         inputs_per_case,
         model,
@@ -416,7 +417,134 @@ def test_qwen2_vl_video_embeddings_input(vllm_runner, video_assets, model,
         [rescale_video_size(video, factor) for factor in size_factors],
     ) for video, prompt in zip(sampled_vids, VIDEO_PROMPTS)]

-    run_test(
+    run_embedding_input_test(
+        vllm_runner,
+        inputs_per_case,
+        model,
+        dtype=dtype,
+        max_tokens=max_tokens,
+        num_logprobs=num_logprobs,
+        mm_limit=1,
+        tensor_parallel_size=1,
+    )
+
+
+def run_chunked_prefill_test(
+    vllm_runner: Type[VllmRunner],
+    inputs: List[Tuple[List[str], PromptImageInput, PromptVideoInput]],
+    model: str,
+    *,
+    dtype: str,
+    max_tokens: int,
+    num_logprobs: int,
+    mm_limit: int,
+    tensor_parallel_size: int,
+    distributed_executor_backend: Optional[str] = None,
+):
+    """Compare inference results between chunked prefill disabled and
+    chunked prefill enabled.
+    """
+
+    # NOTE:
+    # max_model_len should be greater than image_feature_size
+    with vllm_runner(model,
+                     task="generate",
+                     max_model_len=4000,
+                     max_num_seqs=4,
+                     dtype=dtype,
+                     limit_mm_per_prompt={
+                         "image": mm_limit,
+                         "video": mm_limit
+                     },
+                     tensor_parallel_size=tensor_parallel_size,
+                     distributed_executor_backend=distributed_executor_backend
+                     ) as vllm_model:
+
+        outputs_per_case = [
+            vllm_model.generate_greedy_logprobs(prompts,
+                                                max_tokens,
+                                                num_logprobs=num_logprobs,
+                                                images=images or None,
+                                                videos=videos or None)
+            for prompts, images, videos in inputs
+        ]
+
+    with vllm_runner(
+            model,
+            task="generate",
+            max_model_len=4000,
+            max_num_seqs=4,
+            dtype=dtype,
+            limit_mm_per_prompt={
+                "image": mm_limit,
+                "video": mm_limit
+            },
+            tensor_parallel_size=tensor_parallel_size,
+            distributed_executor_backend=distributed_executor_backend,
+            enable_chunked_prefill=True,
+            # should be small enough to ensure prefilling is chunked
+            max_num_batched_tokens=32,
+            mm_processor_kwargs={
+                "max_pixels": 16 * 28 * 28,
+            }) as vllm_model_chunked:
+        outputs_per_case_chunked = [
+            vllm_model_chunked.generate_greedy_logprobs(
+                prompts,
+                max_tokens,
+                num_logprobs=num_logprobs,
+                images=images or None,
+                videos=videos or None) for prompts, images, videos in inputs
+        ]
+
+    for outputs, outputs_chunked in zip(outputs_per_case,
+                                        outputs_per_case_chunked):
+        check_logprobs_close(
+            outputs_0_lst=outputs,
+            outputs_1_lst=outputs_chunked,
+            name_0="non_chunked",
+            name_1="chunked",
+        )
+
+
+@pytest.mark.core_model
+@pytest.mark.parametrize("model", models)
+@pytest.mark.parametrize("dtype", [target_dtype])
+@pytest.mark.parametrize("max_tokens", [1])
+@pytest.mark.parametrize("num_logprobs", [10])
+def test_qwen2_vl_mrope_chunked_prefill(vllm_runner, example_prompts,
+                                        model: str, dtype: str,
+                                        max_tokens: int,
+                                        num_logprobs: int) -> None:
+    """
+    Test Qwen2-VL's chunked prefill with M-RoPE
+    """
+    prompts = [
+        qwen2_vl_chat_template(IMAGE_PLACEHOLDER, prompt)
+        for prompt in example_prompts[:1]
+    ]
+
+    # 1. Qwen2-VL's M-RoPE only takes effect when multi-modal inputs are
+    #    present, so an image is included in the inputs.
+    # 2. However, Qwen2-VL currently does not work properly when chunked
+    #    prefill is enabled together with multi-modal inputs, so as a
+    #    workaround a **zero-length** image is provided to keep it happy.
+    #
+    # This way both goals are met: (1) chunked prefill is enabled and
+    # (2) M-RoPE takes effect, so the test can proceed.
+    zero_len_image = {
+        "image_embeds": torch.empty((0, MODEL_HIDDEN_SIZE)),
+        "image_grid_thw": torch.tensor([[0, 0, 0]])  # 0x0x0 grid -> no image tokens
+    }
+    images = [zero_len_image] * len(prompts)
+
+    inputs_per_case: List[Tuple[List[str], PromptImageInput,
+                                PromptVideoInput]] = [
+        (prompts, images, []),
+    ]
+
+    run_chunked_prefill_test(
         vllm_runner,
         inputs_per_case,
         model,