diff --git a/src/model.py b/src/model.py
index 7a135dc..d9c0976 100644
--- a/src/model.py
+++ b/src/model.py
@@ -46,7 +46,11 @@
 from vllm.utils import random_uuid
 
 from utils.metrics import VllmStatLogger
-from utils.vllm_backend_utils import TritonSamplingParams
+from utils.vllm_backend_utils import (
+    TritonSamplingParams,
+    _get_llama3_prompt,
+    _get_qwen_v2_5_prompt,
+)
 
 _VLLM_ENGINE_ARGS_FILENAME = "model.json"
 _MULTI_LORA_ARGS_FILENAME = "multi_lora.json"
@@ -531,11 +535,15 @@ def _get_input_tensors(self, request):
                 image_rgb = Image.open(BytesIO(image_b)).convert("RGB")
                 images_vllm.append(image_rgb)
             if len(images_vllm) > 0:
-                prompt = {
-                    "prompt": prompt,
-                    "multi_modal_data": {"image": images_vllm},
-                }
-
+                if "llama-3" in self.args["model_name"].lower():
+                    prompt = _get_llama3_prompt(question=prompt, images=images_vllm)
+                elif "qwen2.5" in self.args["model_name"].lower():
+                    prompt = _get_qwen_v2_5_prompt(question=prompt, images=images_vllm)
+                else:
+                    self.logger.log_warning(
+                        "Multi-modal prompt formatting is not implemented for this model. The image input will be ignored.\n"
+                        "Supported models: llama-3, qwen2.5"
+                    )
         # stream
         stream = pb_utils.get_input_tensor_by_name(request, "stream")
         if stream:
diff --git a/src/utils/vllm_backend_utils.py b/src/utils/vllm_backend_utils.py
index 8d330fb..aca1bed 100644
--- a/src/utils/vllm_backend_utils.py
+++ b/src/utils/vllm_backend_utils.py
@@ -27,6 +27,7 @@
 import json
 from typing import Optional
 
+from PIL import Image
 from vllm.sampling_params import GuidedDecodingParams, SamplingParams
 
 
@@ -98,3 +99,27 @@ def from_dict(
             f"[vllm] Was trying to create `TritonSamplingParams`, but got exception: {e}"
         )
         return None
+
+
+def _get_llama3_prompt(question: str, images: list[Image.Image]) -> dict:
+    """Wrap the question and decoded PIL images into a vLLM multi-modal prompt dict."""
+    prompt = {
+        "prompt": question,
+        "multi_modal_data": {"image": images},
+    }
+    return prompt
+
+
+def _get_qwen_v2_5_prompt(question: str, images: list[Image.Image]) -> dict:
+    """Embed the question in the Qwen2.5-VL chat template and attach the PIL images."""
+    placeholder = "<|image_pad|>"
+    prompt = {
+        "prompt": (
+            "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
+            f"<|im_start|>user\n<|vision_start|>{placeholder}<|vision_end|>"
+            f"{question}<|im_end|>\n"
+            "<|im_start|>assistant\n"
+        ),
+        "multi_modal_data": {"image": images},
+    }
+    return prompt
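
For reviewers, a minimal sketch of how the two new helpers behave when called directly, outside of Triton. It assumes the backend's `src/` directory is importable and that `vllm` and `Pillow` are installed; the image size and question text are illustrative only:

```python
from PIL import Image

# Assumption: src/ is on PYTHONPATH so utils.vllm_backend_utils resolves.
from utils.vllm_backend_utils import _get_llama3_prompt, _get_qwen_v2_5_prompt

# Stand-in for an image decoded from the request's "image" tensor.
image = Image.new("RGB", (336, 336))

# llama-3: the question is passed through unchanged alongside the PIL images.
llama_prompt = _get_llama3_prompt(question="What is shown in this image?", images=[image])
assert llama_prompt["prompt"] == "What is shown in this image?"
assert llama_prompt["multi_modal_data"]["image"] == [image]

# qwen2.5: the question is wrapped in the Qwen chat template with vision markers.
qwen_prompt = _get_qwen_v2_5_prompt(question="What is shown in this image?", images=[image])
print(qwen_prompt["prompt"])
```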