[Gaudi][Model] Qwen2.5-vl #870

Open. Wants to merge 34 commits into base: habana_main

Commits (34)
938ef83
Initial commit
ssarkar2 Feb 6, 2025
a3f884b
Comments to trace execution diff between cpu/hpu
ssarkar2 Feb 6, 2025
c83c882
minor
ssarkar2 Feb 7, 2025
c5f65f9
_validate_and_reshape_mm_tensor looks buggy...
ssarkar2 Feb 7, 2025
fca160d
Some comments regd buggy hpu graphs
ssarkar2 Feb 7, 2025
095dbbd
Return early to prevent mem profiling
ssarkar2 Feb 7, 2025
f557d99
Initial commit for the Qwen 2.5 VL
pallavijaini0525 Feb 7, 2025
8c7a2b3
workaround to make HPU graphs work. disable_tensor_cache set to false.
ssarkar2 Feb 9, 2025
22bc3ef
adding qwen2.5-vl to hpu + small cleanups
malkomes Feb 10, 2025
d4a721c
removing duplicates CPU
malkomes Feb 10, 2025
5474d9b
small changes to work with llama-3.2-vl
malkomes Feb 13, 2025
008fbb5
skip profile_run for now
malkomes Feb 16, 2025
f48d6fc
reshape positions in MRotaryEmbedding for HPU
malkomes Feb 18, 2025
4caf383
input positions [3, seq_len] or [seq_len,] for Qwen2.5vl
malkomes Feb 21, 2025
998d090
fix the decoder
malkomes Feb 24, 2025
cd1bbe0
comment prints
malkomes Feb 24, 2025
99f8e9f
cleanup
malkomes Feb 26, 2025
9eac068
polishing
malkomes Feb 26, 2025
dcc2c6c
add type ignore
malkomes Feb 26, 2025
7c5871b
set HPU_DISABLE_TENSOR_CACHE to false for Qwen2.5vl
malkomes Feb 26, 2025
fc9e7ee
make lint happy?
malkomes Feb 26, 2025
67b696e
Change torch dtype to bflat16 for qwen2.5-VL test
jiminha Feb 26, 2025
cf97bed
fea(): Added the tests requirements
imangohari1 Feb 27, 2025
c986f8d
add check_transformers to qwen2_5_VL
malkomes Feb 27, 2025
08b35bf
improving code and comments
malkomes Feb 27, 2025
75eb21b
lint
malkomes Feb 27, 2025
70ef940
remove Optinal
malkomes Feb 27, 2025
15d735c
lint qwen2_5_vl
malkomes Feb 27, 2025
f6b95f8
add reviewers suggestions
malkomes Feb 28, 2025
175a927
lint
malkomes Feb 28, 2025
5baa1ed
remove blank line
malkomes Feb 28, 2025
7fe109a
input_mrope_positions if/else simplifications
malkomes Mar 10, 2025
264676d
Enable FusedSDPA for Qwen2.5 VL
jiminha Mar 10, 2025
cb09a4b
Lint fix
jiminha Mar 10, 2025
2 changes: 1 addition & 1 deletion README_GAUDI.md
@@ -372,7 +372,7 @@ Additionally, there are HPU PyTorch Bridge environment variables impacting vLLM

- `PT_HPU_LAZY_MODE`: if `0`, PyTorch Eager backend for Gaudi will be used, if `1` PyTorch Lazy backend for Gaudi will be used. `1` is the default.
- `PT_HPU_ENABLE_LAZY_COLLECTIVES` must be set to `true` for tensor parallel inference with HPU Graphs.
- `PT_HPUGRAPH_DISABLE_TENSOR_CACHE` must be set to `false` for llava model.
- `PT_HPUGRAPH_DISABLE_TENSOR_CACHE` must be set to `false` for llava and qwen models.
- `VLLM_PROMPT_USE_FLEX_ATTENTION` is enabled only for llama model, and allows to use torch.nn.attention.flex_attention instead of FusedSDPA. Note, this requires `VLLM_PROMPT_USE_FUSEDSDPA=0`

# Quantization, FP8 Inference and Model Calibration Process
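For context, the environment variables documented in this README hunk are read by the Habana PyTorch bridge, so they are typically set before vLLM is imported. Below is a minimal, untested sketch of an offline-inference launch for the Qwen2.5-VL checkpoint used in the tests in this PR; the prompt and sampling values are made up for illustration.

    import os

    # Per the README guidance above: required for llava/qwen models with HPU Graphs,
    # and for tensor-parallel inference with HPU Graphs, respectively.
    os.environ["PT_HPUGRAPH_DISABLE_TENSOR_CACHE"] = "false"
    os.environ["PT_HPU_ENABLE_LAZY_COLLECTIVES"] = "true"
    os.environ["PT_HPU_LAZY_MODE"] = "1"  # default: PyTorch Lazy backend for Gaudi

    from vllm import LLM, SamplingParams

    llm = LLM(model="Qwen/Qwen2.5-VL-3B-Instruct", dtype="bfloat16")
    outputs = llm.generate(["Describe a red square."],
                           SamplingParams(max_tokens=32))
    print(outputs[0].outputs[0].text)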
1 change: 1 addition & 0 deletions requirements-hpu-qwen2_5_vl.txt
@@ -0,0 +1 @@
transformers @ git+https://github.com/huggingface/transformers.git@6b550462139655d488d4c663086a63e98713c6b9

Let's not add a new requirements file per model. Why is a specific SHA required? I believe this should go in the README instead.


Qwen2.5-VL is officially supported starting with Transformers v4.49.0. However, our vLLM fork is currently out of date: it supports only v4.48.3, which does not include Qwen2.5-VL, and the fork's code cannot yet use 4.49. Loading the model fails with:

 File "/root/tf/qwen/vllm-fork-w2/vllm/model_executor/models/registry.py", line 370, in _raise_for_unsupported
    raise ValueError(
ValueError: Model architectures ['Qwen2_5_VLForConditionalGeneration'] failed to be inspected. Please check the logs for more details.

For now, this specific commit works for Qwen2.5-VL without changing too much. Once we update the vLLM fork to the latest upstream and Transformers to 4.49, all of this can go away.


@michalkuligowski FYI: we raised this error on the upstream vLLM repo, and they mentioned it is because of the vllm-fork version. vllm-project#12932 (comment)
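Before debugging registry errors like the one above, it can help to confirm which Transformers build is actually installed. A small sanity-check sketch, not part of this PR; the import only succeeds once the pinned commit or a 4.49+ release is installed:

    import transformers

    # The pinned commit in requirements-hpu-qwen2_5_vl.txt reports a 4.49 dev version;
    # on 4.48.3 the import below fails and the registry inspection error above appears.
    print(transformers.__version__)
    from transformers import Qwen2_5_VLForConditionalGeneration  # noqa: F401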

7 changes: 6 additions & 1 deletion tests/conftest.py
@@ -258,7 +258,12 @@ def wrap_device(self, x: _T, device: Optional[str] = None) -> _T:
return x

if device is None:
device = "cpu" if current_platform.is_cpu() else "cuda"
if current_platform.is_hpu():
device = "hpu"
elif current_platform.is_cpu():
device = "cpu"
else:
device = "cuda"

if isinstance(x, dict):
return {k: self.wrap_device(v, device) for k, v in x.items()}
1 change: 1 addition & 0 deletions tests/models/decoder_only/vision_language/test_models.py
@@ -156,6 +156,7 @@
vllm_output_post_proc=model_utils.qwen2_vllm_to_hf_output,
image_size_factors=[(), (0.25,), (0.25, 0.25, 0.25), (0.25, 0.2, 0.15)],
marks=[pytest.mark.core_model, pytest.mark.cpu_model],
dtype=("bfloat16" if current_platform.is_hpu() else "half")
),
#### Extended model tests
"aria": VLMTestInfo(
2 changes: 1 addition & 1 deletion tests/models/registry.py
@@ -278,7 +278,7 @@ def check_available_online(
"Qwen2AudioForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2-Audio-7B-Instruct"), # noqa: E501
"Qwen2VLForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2-VL-2B-Instruct"), # noqa: E501
"Qwen2_5_VLForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2.5-VL-3B-Instruct", # noqa: E501
min_transformers_version="4.49"), # noqa: E501
min_transformers_version="4.48.9"), # noqa: E501


Why is this decreased?


Please see the comment above regarding the Transformers version.

"UltravoxModel": _HfExamplesInfo("fixie-ai/ultravox-v0_5-llama-3_2-1b",
trust_remote_code=True),
# [Encoder-decoder]
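One plausible explanation for lowering the bound, offered as a guess rather than a statement of the registry's behavior: the pinned commit reports a 4.49 development version, and under PEP 440 a dev pre-release compares below the final release, so a minimum of "4.49" would mark the model unavailable while "4.48.9" accepts it. The comparison itself:

    from packaging.version import Version

    # A dev pre-release of 4.49.0 sorts below the 4.49 release but above 4.48.9.
    assert Version("4.49.0.dev0") < Version("4.49")
    assert Version("4.49.0.dev0") >= Version("4.48.9")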
44 changes: 34 additions & 10 deletions vllm/model_executor/models/qwen2_5_vl.py
@@ -57,7 +57,7 @@
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import MultiModalFieldConfig
from vllm.platforms import _Backend
from vllm.platforms import _Backend, current_platform
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.config import uses_mrope

Expand All @@ -71,6 +71,7 @@
from .vision import get_vit_attn_backend

logger = init_logger(__name__)
is_hpu = current_platform.is_hpu()


This is used in only one place here, so I don't think you need to save it in a variable; that keeps the changes to the model file as small as possible.


We also need this for FusedSDPA; we will update the code.


# === Vision Inputs === #

@@ -312,10 +313,14 @@ def forward(
v_i = v[:, start_idx:end_idx]
q_i, k_i, v_i = (rearrange(x, "b s h d -> b h s d")
for x in [q_i, k_i, v_i])
output_i = F.scaled_dot_product_attention(q_i,
k_i,
v_i,
dropout_p=0.0)
if is_hpu:
from habana_frameworks.torch.hpex.kernels import FusedSDPA
output_i = FusedSDPA.apply(q_i, k_i, v_i, None, 0.0)
else:
output_i = F.scaled_dot_product_attention(q_i,
k_i,
v_i,
dropout_p=0.0)
output_i = rearrange(output_i, "b h s d -> b s h d ")
outputs.append(output_i)
context_layer = torch.cat(outputs, dim=1)
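Condensed, the attention dispatch added above amounts to the helper sketched below. The function name is ours; the tensor layout (batch, heads, seq, head_dim) and the FusedSDPA.apply arguments mirror the diff.

    import torch
    import torch.nn.functional as F

    def sdpa(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
             is_hpu: bool) -> torch.Tensor:
        """Use Habana's fused SDPA kernel on HPU, torch SDPA elsewhere."""
        if is_hpu:
            # Lazy import so non-HPU environments do not need habana_frameworks.
            from habana_frameworks.torch.hpex.kernels import FusedSDPA
            return FusedSDPA.apply(q, k, v, None, 0.0)  # attn_mask=None, dropout_p=0.0
        return F.scaled_dot_product_attention(q, k, v, dropout_p=0.0)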
Expand Down Expand Up @@ -612,11 +617,30 @@ def forward(

# windows attention
window_index, cu_window_seqlens = self.get_window_index(grid_thw)
cu_window_seqlens = torch.tensor(
cu_window_seqlens,
device=hidden_states.device,
dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32)
cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens)

if is_hpu:
# NOTE: unique_consecutive is a dynamic operation
# we are using `remove_duplicates_cpu` instead
def remove_duplicates_cpu(a):
return [
a[i] for i in range(len(a)) if i == 0 or a[i - 1] != a[i]
]

cu_window_seqlens = remove_duplicates_cpu(cu_window_seqlens)
cu_window_seqlens = torch.tensor(
cu_window_seqlens,
device=hidden_states.device,
dtype=grid_thw.dtype
if torch.jit.is_tracing() else torch.int32)

else:
cu_window_seqlens = torch.tensor(
cu_window_seqlens,
device=hidden_states.device,
dtype=grid_thw.dtype
if torch.jit.is_tracing() else torch.int32)
cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens)

seq_len, _ = hidden_states.size()
hidden_states = hidden_states.reshape(
seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
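The reason for remove_duplicates_cpu is that torch.unique_consecutive produces a data-dependent output shape, which interacts badly with HPU graph capture. A quick illustrative check, with made-up window lengths, that the Python-level version matches unique_consecutive on the non-decreasing cumulative sequence:

    import torch

    def remove_duplicates_cpu(a):
        return [a[i] for i in range(len(a)) if i == 0 or a[i - 1] != a[i]]

    cu_window_seqlens = [0, 0, 16, 16, 32, 48, 48]  # illustrative values
    assert remove_duplicates_cpu(cu_window_seqlens) == \
        torch.unique_consecutive(torch.tensor(cu_window_seqlens)).tolist()  # [0, 16, 32, 48]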
30 changes: 30 additions & 0 deletions vllm/utils.py
@@ -823,6 +823,36 @@ def make_tensor_with_pad(
return tensor


def make_mrope_positions_tensor_with_pad( \
input_positions: List[List[int]],
input_mrope_positions: List[List[List[int]]],
max_prompt_len: int,
pad: int) -> List[List[int]]:
# If no mrope positions, returns a flatten (seq_len,)
if all(mrope_position is None for mrope_position in input_mrope_positions):
return make_tensor_with_pad(input_positions,
max_len=max_prompt_len,
pad=0,
dtype=torch.long,
device='cpu').flatten()
# Otherwise, Qwen2.5-VL expects positions in a (3, seq_len)
# we are going to pad each seq_data in the list
# using either MRope values or regular position
mrope_input_positions: List[List[int]] = [[] for _ in range(3)]
for idx in range(3):
for b_idx, input_mrope_position in enumerate(input_mrope_positions):
if input_mrope_position is not None:
positions = input_mrope_position[idx]
else:
positions = input_positions[b_idx]
padding_size = max_prompt_len - len(positions)
assert padding_size >= 0
padded_positions = positions \
+ (max_prompt_len - len(positions)) * [pad]
mrope_input_positions[idx].extend(padded_positions)
return torch.tensor(mrope_input_positions, dtype=torch.long, device='cpu')


def make_tensor_with_pad_align(
x: List[List[T]],
pad: T,
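For reference, a sketch of how the new helper behaves when imported from vllm.utils; the inputs below are made up. With no MRoPE entries it returns a flat tensor of shape (batch * max_prompt_len,), otherwise a (3, batch * max_prompt_len) tensor padded per sequence:

    from vllm.utils import make_mrope_positions_tensor_with_pad

    input_positions = [[0, 1, 2], [0, 1]]

    flat = make_mrope_positions_tensor_with_pad(
        input_positions=input_positions,
        input_mrope_positions=[None, None],  # no multi-modal sequences in the batch
        max_prompt_len=4,
        pad=0)
    print(flat.shape)  # torch.Size([8])

    mrope = make_mrope_positions_tensor_with_pad(
        input_positions=input_positions,
        input_mrope_positions=[[[0, 1, 2]] * 3, None],  # first sequence has MRoPE positions
        max_prompt_len=4,
        pad=0)
    print(mrope.shape)  # torch.Size([3, 8])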
97 changes: 84 additions & 13 deletions vllm/worker/hpu_model_runner.py
@@ -21,7 +21,6 @@
import habana_frameworks.torch as htorch
import habana_frameworks.torch.internal.bridge_config as bc
import torch
import torch.nn as nn
import vllm_hpu_extension.environment as environment
from vllm_hpu_extension.bucketing import HPUBucketingContext
from vllm_hpu_extension.flags import enabled_flags
@@ -44,6 +43,7 @@
from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
@@ -57,7 +57,9 @@
from vllm.sequence import (CompletionSequenceGroupOutput, IntermediateTensors,
Logprob, SequenceData, SequenceGroupMetadata,
SequenceOutput)
from vllm.transformers_utils.config import uses_mrope
from vllm.utils import (bind_kv_cache, is_fake_hpu, is_pin_memory_available,
make_mrope_positions_tensor_with_pad,
make_tensor_with_pad)
from vllm.worker.model_runner_base import (
ModelRunnerBase, ModelRunnerInputBase,
@@ -423,7 +425,7 @@ def _prepare_cos_sin(self, positions):
else:
raise AttributeError(
"The module at the end of the path does not have \
a 'prepare_cos_sin' method.")
a 'prepare_cos_sin' method.")

def forward(self, *args, **kwargs):
kwargs = kwargs.copy()
Expand All @@ -439,7 +441,9 @@ def forward(self, *args, **kwargs):
input_ids.device, self.dtype)
if 'lora_mask' in kwargs:
LoraMask.setLoraMask(kwargs.pop('lora_mask'))
if self.layer_names is not None:
model_config = getattr(self.model, "config", None)
model_is_mrope = uses_mrope(model_config)
if self.layer_names is not None and not model_is_mrope:
self._prepare_cos_sin(kwargs['positions'])

with set_forward_context(kwargs['attn_metadata'], self.vllm_config,
@@ -759,6 +763,11 @@ def _set_gc_threshold(self) -> None:
self.skip_warmup = os.environ.get('VLLM_SKIP_WARMUP',
'false').lower() == 'true'

@property
def model_is_mrope(self) -> bool:
config = self.model_config.hf_config
return uses_mrope(config)

def load_model(self) -> None:
import habana_frameworks.torch.core as htcore
if self.model_config.quantization == 'inc' or \
@@ -878,11 +887,12 @@ def _add_dummy_seq(self, seq_group_metadata_list, is_prompt):

def _maybe_wrap_in_hpu_graph(self, *args, **kwargs):
return htorch.hpu.wrap_in_hpu_graph(
HpuModelAdapter(*args, **kwargs), disable_tensor_cache=True
HpuModelAdapter(*args, **kwargs),
disable_tensor_cache=True,
) if htorch.utils.internal.is_lazy() else HpuModelAdapter(
*args, **kwargs)

def get_model(self) -> nn.Module:
def get_model(self) -> torch.nn.Module:
if isinstance(self.model, HpuModelAdapter):
return self.model.model
return self.model
@@ -929,12 +939,34 @@ def _check_config(self, batch_size, seq_len, attn_metadata, warmup_mode):
"Configuration: (%s, %s, %s, %s) was not warmed-up!", phase,
batch_size, seq_len, num_blocks)

def _get_mrope_positions_and_delta(self, seq_data, mm_kwargs, context_len):
image_grid_thw = mm_kwargs.get("image_grid_thw", None)
video_grid_thw = mm_kwargs.get("video_grid_thw", None)
second_per_grid_ts = mm_kwargs.get("second_per_grid_ts", None)
assert image_grid_thw is not None or video_grid_thw is not None, (
"mrope embedding type requires multi-modal input mapper "
"returns 'image_grid_thw' or 'video_grid_thw'.")
hf_config = self.model_config.hf_config
token_ids = seq_data.get_token_ids()
mrope_positions, mrope_position_delta = \
MRotaryEmbedding.get_input_positions(
token_ids,
hf_config=hf_config,
image_grid_thw=image_grid_thw,
video_grid_thw=video_grid_thw,
second_per_grid_ts=second_per_grid_ts,
context_len=context_len,
)
assert mrope_positions is not None
return mrope_positions, mrope_position_delta
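To make the shape convention concrete: the positions returned here are laid out as (3, seq_len), one row each for the temporal, height, and width axes. For text-only tokens the three rows are identical; image and video tokens make them diverge according to image_grid_thw / video_grid_thw. A minimal text-only sketch, not the production path:

    import torch

    def text_only_mrope_positions(context_len: int, seq_len: int) -> torch.Tensor:
        row = list(range(context_len, seq_len))
        return torch.tensor([row, row, row])  # shape (3, seq_len - context_len)

    print(text_only_mrope_positions(0, 5))
    # tensor([[0, 1, 2, 3, 4],
    #         [0, 1, 2, 3, 4],
    #         [0, 1, 2, 3, 4]])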

def _prepare_prompt(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
) -> PreparePromptMetadata:
input_tokens: List[List[int]] = []
input_positions: List[List[int]] = []
input_mrope_positions: List[List[List[int]]] = []
slot_mapping: List[List[int]] = []
lora_index_mapping: List[List[int]] = []
lora_prompt_mapping: List[List[int]] = []
@@ -1005,6 +1037,7 @@ def _prepare_prompt(
# is always the first token in the sequence.
input_positions.append(list(range(context_len, seq_len)))

seq_data_mrope_positions: Optional[List[List[int]]] = None
if seq_group_metadata.multi_modal_data:
positions = input_positions[0]
mm_data, placeholder_maps = MultiModalPlaceholderMap \
Expand All @@ -1019,12 +1052,29 @@ def _prepare_prompt(
seq_group_metadata.mm_processor_kwargs,
)

# special processing for mrope position deltas.
if self.model_is_mrope:
mrope_positions, mrope_position_delta = \
self._get_mrope_positions_and_delta(
seq_data=seq_data,
mm_kwargs=mm_kwargs,
context_len=context_len)
assert mrope_positions is not None
seq_data.mrope_position_delta = mrope_position_delta
seq_data_mrope_positions = [[] for _ in range(3)]
for idx in range(3):
seq_data_mrope_positions[idx] \
.extend(mrope_positions[idx])

multi_modal_kwargs_list.append(mm_kwargs)

for modality, placeholder_map in placeholder_maps.items():
multi_modal_placeholder_maps[modality].extend(
placeholder_map)

input_mrope_positions.append(
seq_data_mrope_positions) # type: ignore

if seq_group_metadata.block_tables is None:
# During memory profiling, the block tables are not initialized
# yet. In this case, we just use a dummy slot mapping.
@@ -1110,11 +1160,18 @@ def _prepare_prompt(
dtype=torch.long,
device='cpu')

input_positions = make_tensor_with_pad(input_positions,
max_len=max_prompt_len,
pad=0,
dtype=torch.long,
device='cpu')
if self.model_is_mrope:
input_positions = \
make_mrope_positions_tensor_with_pad(input_positions=input_positions,
input_mrope_positions=input_mrope_positions,
max_prompt_len=max_prompt_len,
pad=0)
else:
input_positions = make_tensor_with_pad(input_positions,
max_len=max_prompt_len,
pad=0,
dtype=torch.long,
device='cpu')

slot_mapping = make_tensor_with_pad(slot_mapping,
max_len=max_prompt_len,
@@ -1196,6 +1253,7 @@ def _prepare_decode(
) -> PrepareDecodeMetadata:
input_tokens: List[List[int]] = []
input_positions: List[List[int]] = []
input_mrope_positions: List[List[int]] = [[] for _ in range(3)]
slot_mapping: List[List[int]] = []
seq_lens: List[int] = []
encoder_seq_lens: List[int] = []
@@ -1243,6 +1301,18 @@
position = seq_len - 1
input_positions.append([position])

if self.model_is_mrope:
if seq_data.mrope_position_delta is not None:
pos_for_mrope = MRotaryEmbedding \
.get_next_input_positions(
seq_data.mrope_position_delta,
seq_data.get_num_computed_tokens(),
seq_len)
else:
pos_for_mrope = [[position]] * 3
for idx in range(3):
input_mrope_positions[idx].extend(pos_for_mrope[idx])

seq_len = seq_len if self.sliding_window is None else min(
seq_len, self.sliding_window)
seq_lens.append(seq_len)
@@ -1278,9 +1348,10 @@
real_batch_size = len(seq_group_metadata_list)
input_tokens = output[:real_batch_size].clone()

input_positions = torch.tensor(input_positions,
dtype=torch.long,
device='cpu')
input_positions = torch.tensor(
input_mrope_positions if self.model_is_mrope else input_positions,
dtype=torch.long,
device='cpu')

num_decode_tokens = len(seq_lens)
