[VLM] Implement merged multimodal processor for Mllama (vllm-project#11427)

Signed-off-by: saeediy <saidakbarp@gmail.com>
Isotr0py authored and Said-Akbar committed Mar 7, 2025
1 parent d615607 commit 28cea53
Showing 8 changed files with 456 additions and 233 deletions.
71 changes: 67 additions & 4 deletions tests/models/encoder_decoder/vision_language/test_mllama.py
@@ -7,11 +7,11 @@
from transformers import (AutoConfig, AutoModelForVision2Seq, AutoTokenizer,
BatchEncoding)

from vllm import LLM, SamplingParams
from vllm.attention.backends.flash_attn import FlashAttentionMetadata
from vllm.attention.selector import (_Backend, _cached_get_attn_backend,
global_force_attn_backend_context_manager)
from vllm.model_executor.models.mllama import (MLLAMA_IMAGE_TOKEN_ID,
MllamaForConditionalGeneration)
from vllm.model_executor.models.mllama import MllamaForConditionalGeneration
from vllm.multimodal.image import rescale_image_size
from vllm.sequence import SampleLogprobs

@@ -21,6 +21,7 @@
from ...utils import check_logprobs_close

_LIMIT_IMAGE_PER_PROMPT = 3
MLLAMA_IMAGE_TOKEN_ID = 128256

LIST_ENC_DEC_SUPPORTED_BACKENDS = [_Backend.XFORMERS, _Backend.FLASH_ATTN]

@@ -396,6 +397,64 @@ def test_models_interleaved_images(hf_runner, vllm_runner, image_assets, model,
)


@large_gpu_test(min_gb=48)
@pytest.mark.core_model
@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_tokens", [32])
def test_explicit_implicit_prompt(
image_assets: _ImageAssets,
model: str,
dtype: str,
max_tokens: int,
):
stop_sign = image_assets[0].pil_image
# yapf: disable
prompts = [
# explicit prompt
{
"encoder_prompt": {
"prompt": "<|image|>",
"multi_modal_data": {"image": stop_sign},
},
"decoder_prompt": {
"prompt_token_ids": [128000, 791, 2262, 315, 279, 2217, 220, 128256, 374], # noqa: E501
}
},
{
"encoder_prompt": "Not <|image|>",
"decoder_prompt": "The color of the sky is blue but sometimes it can also be", # noqa: E501
},
# implicit prompt
{
"prompt": "<|begin_of_text|>The content of the image <|image|> is", # noqa: E501
"multi_modal_data": {"image": stop_sign},
},
{
"prompt": "The color of the sky is blue but sometimes it can also be", # noqa: E501
},
]
# yapf: enable
llm = LLM(
model=model,
dtype=dtype,
max_model_len=4096,
max_num_seqs=2,
tensor_parallel_size=1,
enforce_eager=True,
)
sampling_params = SamplingParams(
temperature=0,
max_tokens=max_tokens,
)
outputs = llm.generate(prompts, sampling_params)
n_prompts = len(prompts)
explicit_outputs = outputs[:n_prompts // 2]
implicit_outputs = outputs[n_prompts // 2:]
for exp_output, imp_output in zip(explicit_outputs, implicit_outputs):
assert exp_output.outputs[0].text == imp_output.outputs[0].text


@large_gpu_test(min_gb=48)
@pytest.mark.core_model
@pytest.mark.parametrize("model", models)
@@ -458,6 +517,10 @@ def test_regression(vllm_runner, image_assets, model, dtype, max_tokens,
images=images)


class DummyModel:
image_token_id = MLLAMA_IMAGE_TOKEN_ID


@pytest.mark.core_model
@pytest.mark.parametrize(
"input_indices_and_output",
@@ -499,7 +562,7 @@ def test_get_cross_attention_mask(input_indices_and_output) -> None:
use_cuda_graph=False,
)

dummy: dict[str, str] = {}
dummy = DummyModel()

cross_attention_mask, kv_range_for_decode = MllamaForConditionalGeneration\
.get_cross_attention_mask(dummy,
@@ -556,7 +619,7 @@ def test_get_full_text_row_masked_out_mask(input_indices) -> None:
use_cuda_graph=False,
)

dummy: dict[str, str] = {}
dummy = DummyModel()

full_text_row_masked_out_mask = MllamaForConditionalGeneration\
.get_full_text_row_masked_out_mask(dummy,
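
Aside (not part of the commit): the new `test_explicit_implicit_prompt` above pairs an explicit encoder/decoder prompt, whose decoder `prompt_token_ids` are hard-coded, with an equivalent implicit text prompt and asserts the generations match. Assuming access to the gated `meta-llama/Llama-3.2-11B-Vision-Instruct` tokenizer, the correspondence between the two forms can be sanity-checked directly:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(
    "meta-llama/Llama-3.2-11B-Vision-Instruct")

explicit_decoder_ids = [128000, 791, 2262, 315, 279, 2217, 220, 128256, 374]
implicit_prompt = "<|begin_of_text|>The content of the image <|image|> is"

# 128000 is <|begin_of_text|> and 128256 is the <|image|> placeholder
# (MLLAMA_IMAGE_TOKEN_ID above). Decoding the explicit ids should give back
# the implicit prompt text, and re-encoding that text without letting the
# tokenizer add another bos should reproduce the ids.
print(tokenizer.decode(explicit_decoder_ids))
print(tokenizer.encode(implicit_prompt, add_special_tokens=False)
      == explicit_decoder_ids)
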
13 changes: 11 additions & 2 deletions tests/models/multimodal/processing/test_common.py
@@ -85,6 +85,14 @@ def _test_processing_correctness(
partial(random_audio, rng, min_len=512, max_len=1024, sr=16000),
}

tokenizer_encode_kwargs = {}
if model_config.hf_config.model_type == "mllama":
# For Mllama, the tokenizer always adds bos_token at the beginning of the
# prompt by default, causing the hf_processor to output incorrect token
# ids. So we need to use `add_special_tokens=False` here and leave the
# bos_token to be added by the processor.
tokenizer_encode_kwargs = {"add_special_tokens": False}

for batch_idx in range(num_batches):
mm_data = {
k:
@@ -122,7 +130,7 @@
f"Failed ({batch_idx=}, {prompt=}, {mm_data=})")

baseline_tokenized_result = baseline_processor.apply(
tokenizer.encode(prompt),
tokenizer.encode(prompt, **tokenizer_encode_kwargs),
mm_data=mm_data,
hf_processor_mm_kwargs={},
)
@@ -131,7 +139,7 @@
f"Failed ({batch_idx=}, {prompt=}, {mm_data=})")

cached_tokenized_result = cached_processor.apply(
tokenizer.encode(prompt),
tokenizer.encode(prompt, **tokenizer_encode_kwargs),
mm_data=mm_data,
hf_processor_mm_kwargs={},
)
@@ -155,6 +163,7 @@
"llava-hf/llava-v1.6-mistral-7b-hf",
"llava-hf/LLaVA-NeXT-Video-7B-hf",
"llava-hf/llava-onevision-qwen2-0.5b-ov-hf",
"meta-llama/Llama-3.2-11B-Vision-Instruct",
"TIGER-Lab/Mantis-8B-siglip-llama3",
"mistral-community/pixtral-12b",
"openbmb/MiniCPM-o-2_6",
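
Aside (not part of the commit): the `add_special_tokens=False` special case above exists because Mllama's tokenizer prepends `<|begin_of_text|>` on its own while the HF processor also handles it, so encoding with the defaults before calling the processor would duplicate the bos token. A minimal sketch of that behavior, again assuming the gated Llama 3.2 Vision tokenizer is available locally:

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained(
    "meta-llama/Llama-3.2-11B-Vision-Instruct")

prompt = "<|image|>What is shown in this image?"

# Default encode prepends bos (128000); the HF processor would add it again.
print(tok.encode(prompt)[:2])                            # e.g. [128000, 128256]
print(tok.encode(prompt, add_special_tokens=False)[:2])  # e.g. [128256, ...]
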
90 changes: 83 additions & 7 deletions vllm/inputs/preprocess.py
@@ -1,15 +1,16 @@
# SPDX-License-Identifier: Apache-2.0

import asyncio
from typing import List, Mapping, Optional, Union
from typing import List, Mapping, Optional, Tuple, Union, cast

from typing_extensions import assert_never

from vllm.config import ModelConfig
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.multimodal.inputs import MultiModalDataDict, MultiModalInputs
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalEncDecInputs,
MultiModalInputs)
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup

@@ -495,6 +496,51 @@ def _build_enc_dec_llm_inputs(
decoder=decoder_inputs,
)

def _separate_enc_dec_inputs_from_mm_processor_outputs(
self,
inputs: SingletonInputs,
decoder_inputs_to_override: Optional[SingletonInputs] = None,
) -> Tuple[SingletonInputs, SingletonInputs]:
"""
For encoder/decoder models only:
Separate Encoder/Decoder inputs from a MultiModalEncDecInputs
"""
encoder_inputs: SingletonInputs
decoder_inputs: SingletonInputs
if inputs["type"] == "multimodal":
# Multimodal data inputs
assert ("encoder_prompt" in inputs
and "encoder_prompt_token_ids" in inputs)
inputs = cast(MultiModalEncDecInputs, inputs)
encoder_inputs = token_inputs(
prompt=inputs["encoder_prompt"],
prompt_token_ids=inputs["encoder_prompt_token_ids"],
)
if decoder_inputs_to_override is not None:
decoder_inputs = MultiModalInputs(
type="multimodal",
prompt=decoder_inputs_to_override.get("prompt", ""),
prompt_token_ids=decoder_inputs_to_override[
"prompt_token_ids"],
mm_kwargs=inputs["mm_kwargs"],
mm_placeholders=inputs["mm_placeholders"],
)
else:
decoder_inputs = MultiModalInputs(
type="multimodal",
prompt=inputs["prompt"],
prompt_token_ids=inputs["prompt_token_ids"],
mm_kwargs=inputs["mm_kwargs"],
mm_placeholders=inputs["mm_placeholders"],
)
elif inputs["type"] == "token":
# Text-only inputs
encoder_inputs = token_inputs(prompt="", prompt_token_ids=[])
decoder_inputs = decoder_inputs_to_override or inputs
else:
assert_never(inputs) # type: ignore[arg-type]
return encoder_inputs, decoder_inputs

def _process_encoder_decoder_prompt(
self,
prompt: PromptType,
@@ -539,21 +585,35 @@ def _process_encoder_decoder_prompt(
prompt["encoder_prompt"],
request_id=request_id,
)

if (decoder_input := prompt["decoder_prompt"]) is None:
decoder_inputs = None
else:
decoder_inputs = self._prompt_to_llm_inputs(
decoder_input,
request_id=request_id,
)
# For multimodal models, override the decoder prompt from the
# processor with the explicit decoder prompt.
if self.model_config.is_multimodal_model and (
self._can_process_multimodal()):
encoder_inputs, decoder_inputs = (
self._separate_enc_dec_inputs_from_mm_processor_outputs(
encoder_inputs, decoder_inputs))
else:
encoder_inputs = self._prompt_to_llm_inputs(
inputs = self._prompt_to_llm_inputs(
prompt,
request_id=request_id,
)
if self.model_config.is_multimodal_model and (
self._can_process_multimodal()):
# Encoder-Decoder Multimodal model
encoder_inputs, decoder_inputs = (
self._separate_enc_dec_inputs_from_mm_processor_outputs(
inputs))
else:
encoder_inputs = inputs

decoder_inputs = None
decoder_inputs = None

return self._build_enc_dec_llm_inputs(encoder_inputs, decoder_inputs)

@@ -583,13 +643,29 @@ async def _process_encoder_decoder_prompt_async(

encoder_inputs, decoder_inputs = await asyncio.gather(
encoder_task, decoder_task)

# For multimodal models, override the decoder prompt from the
# processor with the explicit decoder prompt.
if self.model_config.is_multimodal_model and (
self._can_process_multimodal()):
encoder_inputs, decoder_inputs = (
self._separate_enc_dec_inputs_from_mm_processor_outputs(
encoder_inputs, decoder_inputs))
else:
encoder_inputs = await self._prompt_to_llm_inputs_async(
inputs = await self._prompt_to_llm_inputs_async(
prompt,
request_id=request_id,
)
if self.model_config.is_multimodal_model and (
self._can_process_multimodal()):
# Encoder-Decoder Multimodal model
encoder_inputs, decoder_inputs = (
self._separate_enc_dec_inputs_from_mm_processor_outputs(
inputs))
else:
encoder_inputs = inputs

decoder_inputs = None
decoder_inputs = None

return self._build_enc_dec_llm_inputs(encoder_inputs, decoder_inputs)

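
Aside (not part of the commit): the core of the preprocess.py change is `_separate_enc_dec_inputs_from_mm_processor_outputs`. The merged multimodal processor returns a single `MultiModalEncDecInputs`-style dict, and the preprocessor splits it into encoder inputs (the image-placeholder prompt) and decoder inputs (the text prompt plus the processed multimodal tensors), with an explicit decoder prompt taking precedence when one was supplied. A standalone sketch of that split, using plain dicts in place of vLLM's typed inputs (field values are illustrative only):

from typing import Optional, Tuple


def separate_enc_dec(inputs: dict,
                     decoder_override: Optional[dict] = None
                     ) -> Tuple[dict, dict]:
    if inputs["type"] == "multimodal":
        # The encoder side only needs the image-placeholder prompt.
        encoder = {
            "type": "token",
            "prompt": inputs["encoder_prompt"],
            "prompt_token_ids": inputs["encoder_prompt_token_ids"],
        }
        # Decoder text comes from the explicit decoder prompt if given,
        # otherwise from the processor output; the multimodal tensors
        # always come from the processor output.
        source = decoder_override if decoder_override is not None else inputs
        decoder = {
            "type": "multimodal",
            "prompt": source.get("prompt", ""),
            "prompt_token_ids": source["prompt_token_ids"],
            "mm_kwargs": inputs["mm_kwargs"],
            "mm_placeholders": inputs["mm_placeholders"],
        }
        return encoder, decoder
    # Text-only request: nothing for the encoder to attend to.
    encoder = {"type": "token", "prompt": "", "prompt_token_ids": []}
    return encoder, decoder_override or inputs


merged = {
    "type": "multimodal",
    "prompt": "<|begin_of_text|>The content of the image <|image|> is",
    "prompt_token_ids": [128000, 791, 2262, 315, 279, 2217, 220, 128256, 374],
    "encoder_prompt": "<|image|>",
    "encoder_prompt_token_ids": [128256],  # the real processor expands this
    "mm_kwargs": {"pixel_values": "..."},
    "mm_placeholders": {"image": []},
}
enc, dec = separate_enc_dec(merged)
print(enc["prompt_token_ids"], dec["prompt_token_ids"])
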
3 changes: 2 additions & 1 deletion vllm/inputs/registry.py
@@ -350,7 +350,8 @@ def dummy_data_for_profiling(
)
processor = mm_registry.create_processor(model_config, tokenizer)
profiler = MultiModalProfiler(processor)
dummy_data = profiler.get_dummy_data(seq_len)
dummy_data = profiler.get_dummy_data(
seq_len, is_encoder_data=is_encoder_data)
else:
model_cls, _ = get_model_architecture(model_config)
if is_encoder_data: