Skip to content

Commit 2772fba

Browse files
DarkLight1337BKitor
authored andcommitted
[Bugfix] Cleanup Pixtral HF code (vllm-project#11333)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
1 parent a339378 commit 2772fba

File tree

1 file changed

+14
-141
lines changed

1 file changed

+14
-141
lines changed

vllm/model_executor/models/pixtral.py

+14-141
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,12 @@
1010
from PIL import Image
1111
from transformers import PixtralVisionConfig
1212
from transformers.models.pixtral.image_processing_pixtral import (
13-
_num_image_tokens)
13+
_num_image_tokens as _get_pixtral_hf_num_image_tokens)
1414
from transformers.models.pixtral.modeling_pixtral import (
1515
PixtralRotaryEmbedding, apply_rotary_pos_emb, position_ids_in_meshgrid)
1616

1717
from vllm.attention import AttentionMetadata
18-
from vllm.config import ModelConfig, VllmConfig
18+
from vllm.config import VllmConfig
1919
from vllm.distributed import divide, get_tensor_model_parallel_world_size
2020
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
2121
InputContext, token_inputs)
@@ -27,19 +27,17 @@
2727
from vllm.model_executor.layers.quantization import QuantizationConfig
2828
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
2929
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
30-
from vllm.model_executor.models.utils import merge_multimodal_embeddings
3130
from vllm.model_executor.sampling_metadata import SamplingMetadata
3231
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
3332
from vllm.multimodal.inputs import NestedTensors, PlaceholderRange
3433
from vllm.multimodal.utils import (cached_get_tokenizer,
3534
consecutive_placeholder_ranges,
3635
resolve_visual_encoder_outputs)
3736
from vllm.sequence import IntermediateTensors, SequenceData
38-
from vllm.transformers_utils.processor import cached_get_processor
39-
from vllm.utils import is_list_of
4037

4138
from .interfaces import SupportsMultiModal, SupportsPP
42-
from .utils import init_vllm_registered_model, maybe_prefix
39+
from .utils import (init_vllm_registered_model, maybe_prefix,
40+
merge_multimodal_embeddings)
4341

4442
try:
4543
from xformers import ops as xops
@@ -699,37 +697,14 @@ def get_pixtral_hf_num_patches(*, image_size: int, patch_size: int) -> int:
699697
return grid_length * grid_length
700698

701699

702-
def get_max_pixtral_hf_image_feature_size(
703-
hf_config: PixtralVisionConfig) -> int:
704-
return get_pixtral_hf_num_patches(image_size=hf_config.image_size,
705-
patch_size=hf_config.patch_size)
706-
707-
708700
def get_max_pixtral_hf_image_tokens(hf_config: PixtralVisionConfig) -> int:
709-
return get_max_pixtral_hf_image_feature_size(hf_config)
701+
grid_length = get_pixtral_hf_patch_grid_length(
702+
image_size=hf_config.image_size,
703+
patch_size=hf_config.patch_size,
704+
)
710705

711-
712-
def dummy_seq_data_for_pixtral_hf(
713-
hf_config: PixtralVisionConfig,
714-
seq_len: int,
715-
num_images: int,
716-
*,
717-
image_token_id: int,
718-
image_feature_size_override: Optional[int] = None,
719-
mm_key: str = "image"):
720-
if image_feature_size_override is None:
721-
image_feature_size = get_max_pixtral_hf_image_feature_size(hf_config)
722-
else:
723-
image_feature_size = image_feature_size_override
724-
725-
return SequenceData.from_prompt_token_counts(
726-
(image_token_id, image_feature_size * num_images),
727-
(0, seq_len - image_feature_size * num_images),
728-
), {
729-
mm_key:
730-
consecutive_placeholder_ranges(num_items=num_images,
731-
item_size=image_feature_size)
732-
}
706+
# Consider the image_break_token
707+
return (grid_length + 1) * grid_length
733708

734709

735710
def dummy_image_for_pixtral_hf(
@@ -763,116 +738,14 @@ def get_pixtral_hf_image_feature_size(hf_config: PixtralVisionConfig,
763738
image_width = int(numpy.ceil(image_width / ratio))
764739
image_height = int(numpy.ceil(image_height / ratio))
765740

766-
num_height_tokens, num_width_tokens = _num_image_tokens(
767-
(image_height, image_width), (patch_height, patch_width))
741+
num_height_tokens, num_width_tokens = _get_pixtral_hf_num_image_tokens(
742+
(image_height, image_width),
743+
(patch_height, patch_width),
744+
)
768745

769746
return num_width_tokens, num_height_tokens
770747

771748

772-
def input_processor_for_pixtral_hf(
773-
model_config: ModelConfig,
774-
hf_config: PixtralVisionConfig,
775-
inputs: DecoderOnlyInputs,
776-
*,
777-
image_token_id: int,
778-
image_feature_size_override: Optional[Union[int, List[int]]] = None,
779-
) -> DecoderOnlyInputs:
780-
assert image_feature_size_override is None, (
781-
"image_feature_size_override is not supported for Pixtral")
782-
783-
multi_modal_data = inputs.get("multi_modal_data")
784-
if multi_modal_data is None or "image" not in multi_modal_data:
785-
return inputs
786-
787-
processor = cached_get_processor(model_config.model)
788-
789-
image_data = multi_modal_data["image"]
790-
if isinstance(image_data, Image.Image):
791-
image_data = [image_data]
792-
elif not is_list_of(image_data, Image.Image):
793-
raise TypeError(f"Invalid image type: {type(image_data)}")
794-
795-
new_prompt = inputs.get("prompt")
796-
new_token_ids = inputs["prompt_token_ids"]
797-
798-
image_token = processor.image_token
799-
image_break_token = processor.image_break_token
800-
image_end_token = processor.image_end_token
801-
802-
# Update new_prompt if present
803-
if new_prompt:
804-
parts = new_prompt.split(image_token)
805-
assert len(parts) - 1 == len(image_data)
806-
new_parts = [parts[0]] # Start with the part before any image tokens
807-
808-
for image, next_part in zip(image_data, parts[1:]):
809-
w, h = image.size
810-
(num_width_tokens,
811-
num_height_tokens) = get_pixtral_hf_image_feature_size(
812-
hf_config, image_width=w, image_height=h)
813-
814-
replace_tokens = [image_token] * num_width_tokens + [
815-
image_break_token
816-
]
817-
replace_tokens = replace_tokens * num_height_tokens
818-
replace_tokens[-1] = image_end_token
819-
820-
new_parts.append("".join(replace_tokens))
821-
new_parts.append(next_part)
822-
823-
new_prompt = "".join(new_parts)
824-
825-
# Update new_token_ids
826-
convert_tokens_to_ids = processor.tokenizer.convert_tokens_to_ids
827-
image_token_id = convert_tokens_to_ids(image_token)
828-
image_break_id = convert_tokens_to_ids(image_break_token)
829-
image_end_id = convert_tokens_to_ids(image_end_token)
830-
placeholder_token_id = -999
831-
# Find all image token indices at once
832-
placeholder_indices = [
833-
idx for idx, token_id in enumerate(new_token_ids)
834-
if token_id == image_token_id
835-
]
836-
assert len(placeholder_indices) == len(image_data)
837-
replace_tokens_list = []
838-
for placeholder_idx, image in zip(placeholder_indices, image_data):
839-
new_token_ids[placeholder_idx] = placeholder_token_id
840-
841-
w, h = image.size
842-
(num_width_tokens,
843-
num_height_tokens) = get_pixtral_hf_image_feature_size(hf_config,
844-
image_width=w,
845-
image_height=h)
846-
847-
replace_tokens = [image_token_id] * num_width_tokens + [image_break_id]
848-
replace_tokens = replace_tokens * num_height_tokens
849-
replace_tokens[-1] = image_end_id
850-
replace_tokens_list.append(replace_tokens)
851-
852-
reverse_offsets: List[int] = []
853-
# Backward iteration for replacement without affecting known indices
854-
for placeholder_idx, replace_tokens in zip(reversed(placeholder_indices),
855-
reversed(replace_tokens_list)):
856-
reverse_offsets.append(
857-
len(new_token_ids) - placeholder_idx + len(replace_tokens))
858-
new_token_ids[placeholder_idx:placeholder_idx + 1] = replace_tokens
859-
860-
placeholder_ranges: List[PlaceholderRange] = []
861-
for reverse_offset, replace_tokens in zip(reversed(reverse_offsets),
862-
replace_tokens_list):
863-
placeholder_ranges.append(
864-
PlaceholderRange(
865-
offset=len(new_token_ids) - reverse_offset,
866-
length=len(replace_tokens),
867-
))
868-
869-
# NOTE: Create a defensive copy of the original inputs
870-
return token_inputs(prompt_token_ids=new_token_ids,
871-
prompt=new_prompt,
872-
multi_modal_data=multi_modal_data,
873-
multi_modal_placeholders={"image": placeholder_ranges})
874-
875-
876749
class PixtralHFMLP(nn.Module):
877750

878751
def __init__(

0 commit comments

Comments
 (0)