@@ -10,12 +10,12 @@
 from PIL import Image
 from transformers import PixtralVisionConfig
 from transformers.models.pixtral.image_processing_pixtral import (
-    _num_image_tokens)
+    _num_image_tokens as _get_pixtral_hf_num_image_tokens)
 from transformers.models.pixtral.modeling_pixtral import (
     PixtralRotaryEmbedding, apply_rotary_pos_emb, position_ids_in_meshgrid)

 from vllm.attention import AttentionMetadata
-from vllm.config import ModelConfig, VllmConfig
+from vllm.config import VllmConfig
 from vllm.distributed import divide, get_tensor_model_parallel_world_size
 from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
                          InputContext, token_inputs)
@@ -27,19 +27,17 @@
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
-from vllm.model_executor.models.utils import merge_multimodal_embeddings
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
 from vllm.multimodal.inputs import NestedTensors, PlaceholderRange
 from vllm.multimodal.utils import (cached_get_tokenizer,
                                    consecutive_placeholder_ranges,
                                    resolve_visual_encoder_outputs)
 from vllm.sequence import IntermediateTensors, SequenceData
-from vllm.transformers_utils.processor import cached_get_processor
-from vllm.utils import is_list_of

 from .interfaces import SupportsMultiModal, SupportsPP
-from .utils import init_vllm_registered_model, maybe_prefix
+from .utils import (init_vllm_registered_model, maybe_prefix,
+                    merge_multimodal_embeddings)

 try:
     from xformers import ops as xops
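Note on the rename: `_num_image_tokens` is a private helper in `transformers`, so aliasing it to `_get_pixtral_hf_num_image_tokens` documents its role at the call site below. As a rough sketch of its assumed contract (inferred from how it is called in this file, not copied from the `transformers` source), it maps an image shape and a patch shape to per-axis token counts via ceiling division:

```python
import math

# Hypothetical stand-in for the aliased transformers helper; the real
# implementation may differ, but the call site below only relies on it
# returning (num_height_tokens, num_width_tokens) for the given shapes.
def num_image_tokens(image_hw: tuple, patch_hw: tuple) -> tuple:
    (height, width), (patch_height, patch_width) = image_hw, patch_hw
    return (math.ceil(height / patch_height), math.ceil(width / patch_width))

assert num_image_tokens((512, 768), (16, 16)) == (32, 48)
```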
@@ -699,37 +697,14 @@ def get_pixtral_hf_num_patches(*, image_size: int, patch_size: int) -> int:
     return grid_length * grid_length


-def get_max_pixtral_hf_image_feature_size(
-        hf_config: PixtralVisionConfig) -> int:
-    return get_pixtral_hf_num_patches(image_size=hf_config.image_size,
-                                      patch_size=hf_config.patch_size)
-
-
 def get_max_pixtral_hf_image_tokens(hf_config: PixtralVisionConfig) -> int:
-    return get_max_pixtral_hf_image_feature_size(hf_config)
+    grid_length = get_pixtral_hf_patch_grid_length(
+        image_size=hf_config.image_size,
+        patch_size=hf_config.patch_size,
+    )

-
-def dummy_seq_data_for_pixtral_hf(
-        hf_config: PixtralVisionConfig,
-        seq_len: int,
-        num_images: int,
-        *,
-        image_token_id: int,
-        image_feature_size_override: Optional[int] = None,
-        mm_key: str = "image"):
-    if image_feature_size_override is None:
-        image_feature_size = get_max_pixtral_hf_image_feature_size(hf_config)
-    else:
-        image_feature_size = image_feature_size_override
-
-    return SequenceData.from_prompt_token_counts(
-        (image_token_id, image_feature_size * num_images),
-        (0, seq_len - image_feature_size * num_images),
-    ), {
-        mm_key:
-        consecutive_placeholder_ranges(num_items=num_images,
-                                       item_size=image_feature_size)
-    }
+    # Consider the image_break_token
+    return (grid_length + 1) * grid_length


 def dummy_image_for_pixtral_hf(
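A quick sanity check of the new maximum: each of the `grid_length` patch rows contributes `grid_length` image tokens plus one break/end token, which the removed `get_max_pixtral_hf_image_feature_size` (plain `grid_length ** 2`) did not account for. A minimal standalone sketch, assuming the common Pixtral vision defaults of `image_size=1024` and `patch_size=16`, and that the grid length is a simple division:

```python
# Assumed Pixtral vision-config defaults; the real values come from
# PixtralVisionConfig at runtime.
image_size, patch_size = 1024, 16

grid_length = image_size // patch_size  # 64 patches per side

# Each row of `grid_length` image tokens is followed by one image_break
# token (the last row ends with image_end instead, which leaves the
# count unchanged), hence the +1 per row.
assert (grid_length + 1) * grid_length == 4160
```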
@@ -763,116 +738,14 @@ def get_pixtral_hf_image_feature_size(hf_config: PixtralVisionConfig,
     image_width = int(numpy.ceil(image_width / ratio))
     image_height = int(numpy.ceil(image_height / ratio))

-    num_height_tokens, num_width_tokens = _num_image_tokens(
-        (image_height, image_width), (patch_height, patch_width))
+    num_height_tokens, num_width_tokens = _get_pixtral_hf_num_image_tokens(
+        (image_height, image_width),
+        (patch_height, patch_width),
+    )

     return num_width_tokens, num_height_tokens


-def input_processor_for_pixtral_hf(
-        model_config: ModelConfig,
-        hf_config: PixtralVisionConfig,
-        inputs: DecoderOnlyInputs,
-        *,
-        image_token_id: int,
-        image_feature_size_override: Optional[Union[int, List[int]]] = None,
-) -> DecoderOnlyInputs:
-    assert image_feature_size_override is None, (
-        "image_feature_size_override is not supported for Pixtral")
-
-    multi_modal_data = inputs.get("multi_modal_data")
-    if multi_modal_data is None or "image" not in multi_modal_data:
-        return inputs
-
-    processor = cached_get_processor(model_config.model)
-
-    image_data = multi_modal_data["image"]
-    if isinstance(image_data, Image.Image):
-        image_data = [image_data]
-    elif not is_list_of(image_data, Image.Image):
-        raise TypeError(f"Invalid image type: {type(image_data)}")
-
-    new_prompt = inputs.get("prompt")
-    new_token_ids = inputs["prompt_token_ids"]
-
-    image_token = processor.image_token
-    image_break_token = processor.image_break_token
-    image_end_token = processor.image_end_token
-
-    # Update new_prompt if present
-    if new_prompt:
-        parts = new_prompt.split(image_token)
-        assert len(parts) - 1 == len(image_data)
-        new_parts = [parts[0]]  # Start with the part before any image tokens
-
-        for image, next_part in zip(image_data, parts[1:]):
-            w, h = image.size
-            (num_width_tokens,
-             num_height_tokens) = get_pixtral_hf_image_feature_size(
-                 hf_config, image_width=w, image_height=h)
-
-            replace_tokens = [image_token] * num_width_tokens + [
-                image_break_token
-            ]
-            replace_tokens = replace_tokens * num_height_tokens
-            replace_tokens[-1] = image_end_token
-
-            new_parts.append("".join(replace_tokens))
-            new_parts.append(next_part)
-
-        new_prompt = "".join(new_parts)
-
-    # Update new_token_ids
-    convert_tokens_to_ids = processor.tokenizer.convert_tokens_to_ids
-    image_token_id = convert_tokens_to_ids(image_token)
-    image_break_id = convert_tokens_to_ids(image_break_token)
-    image_end_id = convert_tokens_to_ids(image_end_token)
-    placeholder_token_id = -999
-    # Find all image token indices at once
-    placeholder_indices = [
-        idx for idx, token_id in enumerate(new_token_ids)
-        if token_id == image_token_id
-    ]
-    assert len(placeholder_indices) == len(image_data)
-    replace_tokens_list = []
-    for placeholder_idx, image in zip(placeholder_indices, image_data):
-        new_token_ids[placeholder_idx] = placeholder_token_id
-
-        w, h = image.size
-        (num_width_tokens,
-         num_height_tokens) = get_pixtral_hf_image_feature_size(hf_config,
-                                                                image_width=w,
-                                                                image_height=h)
-
-        replace_tokens = [image_token_id] * num_width_tokens + [image_break_id]
-        replace_tokens = replace_tokens * num_height_tokens
-        replace_tokens[-1] = image_end_id
-        replace_tokens_list.append(replace_tokens)
-
-    reverse_offsets: List[int] = []
-    # Backward iteration for replacement without affecting known indices
-    for placeholder_idx, replace_tokens in zip(reversed(placeholder_indices),
-                                               reversed(replace_tokens_list)):
-        reverse_offsets.append(
-            len(new_token_ids) - placeholder_idx + len(replace_tokens))
-        new_token_ids[placeholder_idx:placeholder_idx + 1] = replace_tokens
-
-    placeholder_ranges: List[PlaceholderRange] = []
-    for reverse_offset, replace_tokens in zip(reversed(reverse_offsets),
-                                              replace_tokens_list):
-        placeholder_ranges.append(
-            PlaceholderRange(
-                offset=len(new_token_ids) - reverse_offset,
-                length=len(replace_tokens),
-            ))
-
-    # NOTE: Create a defensive copy of the original inputs
-    return token_inputs(prompt_token_ids=new_token_ids,
-                        prompt=new_prompt,
-                        multi_modal_data=multi_modal_data,
-                        multi_modal_placeholders={"image": placeholder_ranges})
-
-
 class PixtralHFMLP(nn.Module):

     def __init__(
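For context on the removal: the legacy `input_processor_for_pixtral_hf` expanded each image placeholder into a row-by-row token layout, and that expansion is easy to see in isolation. A minimal sketch of the same expansion; the `[IMG]`-style strings are illustrative placeholders (the real token strings come from the HF processor) and the helper name is invented here:

```python
from typing import List


def pixtral_image_token_layout(num_width_tokens: int,
                               num_height_tokens: int,
                               image_token: str = "[IMG]",
                               image_break_token: str = "[IMG_BREAK]",
                               image_end_token: str = "[IMG_END]") -> List[str]:
    # Same expansion the removed processor performed: each patch row is
    # `num_width_tokens` image tokens plus a break token, and the final
    # break is replaced by the end token.
    row = [image_token] * num_width_tokens + [image_break_token]
    tokens = row * num_height_tokens
    tokens[-1] = image_end_token
    return tokens


# A 3x2 patch grid expands to (3 + 1) * 2 = 8 tokens.
layout = pixtral_image_token_layout(num_width_tokens=3, num_height_tokens=2)
assert len(layout) == 8 and layout[-1] == "[IMG_END]"
```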