Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

[V1] Fix multimodal profiling for Molmo #11325

Merged
merged 3 commits into from
Dec 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions vllm/model_executor/models/molmo.py
Original file line number Diff line number Diff line change
Expand Up @@ -928,7 +928,11 @@ def image_input_mapper_for_molmo(
data: object,
):
if isinstance(data, list):
assert len(data) == 1, "Molmo supports only one image per prompt."
data = data[0]

# Remove unused dummy PIL image
data.pop('raw_mm_data', None)
return MultiModalKwargs(data)


Expand Down Expand Up @@ -974,6 +978,7 @@ def dummy_data_for_molmo(ctx: InputContext, seq_len: int,
dummy_imgdata = {
"images": out["images"],
"image_input_idx": out["image_input_idx"],
"raw_mm_data": dummy_image,
}
if "image_masks" in out:
dummy_imgdata["image_masks"] = out["image_masks"]
Expand Down
19 changes: 17 additions & 2 deletions vllm/v1/engine/mm_input_mapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,17 +151,31 @@ class MMHasher:
def __init__(self):
pass

def hash_mm_data(
def hash_dummy_mm_data(
self,
mm_data: Optional[MultiModalDataDict]) -> Optional[List[str]]:
"""Hash user-defined dummy multimodal data used for profiling."""

if mm_data is None:
return None

image_inputs = mm_data['image']

# This is a temporary workaround for models (e.g., Molmo) that
# process multimodal data in the input processor (therefore
# image_inputs is MultiModalKwargs instead of raw input format).
# `raw_mm_data` with the original input format is expected
# in this case.
if isinstance(image_inputs, dict):
assert "raw_mm_data" in image_inputs and isinstance(
image_inputs["raw_mm_data"], PIL.Image.Image)
image_inputs = image_inputs.pop("raw_mm_data")

return self.hash_images(image_inputs)

def hash_prompt(self, prompt: PromptType) -> Optional[List[str]]:
def hash_prompt_mm_data(self, prompt: PromptType) -> Optional[List[str]]:
"""Hash multimodal data in the user input prompt if they exist."""

if "multi_modal_data" not in prompt:
return None

Expand All @@ -171,6 +185,7 @@ def hash_prompt(self, prompt: PromptType) -> Optional[List[str]]:
return self.hash_images(image_inputs)

def hash_images(self, image_inputs) -> Optional[List[str]]:
"""Hash PIL image objects to strings."""
if not isinstance(image_inputs, list):
image_inputs = [image_inputs]
assert len(image_inputs) > 0
Expand Down
2 changes: 1 addition & 1 deletion vllm/v1/engine/processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def process_inputs(
# Compute MM hashes (if enabled)
mm_hashes = None
if self.use_hash:
mm_hashes = self.mm_hasher.hash_prompt(prompt)
mm_hashes = self.mm_hasher.hash_prompt_mm_data(prompt)

# Process inputs.
preprocessed_inputs = self.input_preprocessor.preprocess(
Expand Down
2 changes: 1 addition & 1 deletion vllm/v1/worker/gpu_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -638,7 +638,7 @@ def profile_run(self) -> None:
# Compute MM hashes (if enabled)
mm_hashes = None
if self.use_hash:
mm_hashes = self.mm_hasher.hash_mm_data(dummy_mm_data)
mm_hashes = self.mm_hasher.hash_dummy_mm_data(dummy_mm_data)

dummy_mm_kwargs = self.mm_input_mapper_client.process_inputs(
mm_data=dummy_mm_data,
Expand Down
Loading