diff --git a/tests/multimodal/test_base.py b/tests/multimodal/test_base.py index f19a0f33fe067..e9562d2048f06 100644 --- a/tests/multimodal/test_base.py +++ b/tests/multimodal/test_base.py @@ -81,3 +81,15 @@ def test_multimodal_input_batch_multiple_batchable_lists(): result, {"image": torch.stack([torch.stack([a, b]), torch.stack([c, d])])}) + + +def test_multimodal_input_batch_mixed_stacking_depths(): + a = torch.rand([1, 2, 3]) + b = torch.rand([1, 3, 3]) + c = torch.rand([1, 4, 3]) + + result = MultiModalInputs.batch([{"image": [a, b]}, {"image": [c]}]) + assert_multimodal_inputs_equal(result, {"image": [[a, b], c.unsqueeze(0)]}) + + result = MultiModalInputs.batch([{"image": [a]}, {"image": [b, c]}]) + assert_multimodal_inputs_equal(result, {"image": [a.unsqueeze(0), [b, c]]}) diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py index 6e7ee511bf27f..16565e1467e8f 100644 --- a/vllm/model_executor/models/utils.py +++ b/vllm/model_executor/models/utils.py @@ -1,7 +1,6 @@ from typing import (Dict, Iterable, List, Literal, Optional, Protocol, Tuple, Union, overload) -import numpy as np import torch import torch.nn as nn from torch.func import functional_call @@ -96,12 +95,13 @@ def flatten_bn( def _flatten_embeddings(embeddings: NestedTensors) -> torch.Tensor: """ - Recursively concatenates NestedTensors along any heterogeneously sized - dimensions. + Recursively flattens and concatenates NestedTensors on all but the last + dimension. """ if isinstance(embeddings, torch.Tensor): - return embeddings + # Flatten all but the last dimension. + return embeddings.flatten(0, -2) return torch.cat(tuple(_flatten_embeddings(t) for t in embeddings)) @@ -136,15 +136,13 @@ def merge_multimodal_embeddings(input_ids: torch.Tensor, assert isinstance(num_expected_tokens, int) flattened = _flatten_embeddings(multimodal_embeddings) - *dims, embed_dim = flattened.shape - num_multimodal_embeddings = np.prod(dims) - if num_multimodal_embeddings != num_expected_tokens: + if flattened.shape[0] != num_expected_tokens: expr = _embedding_count_expression(multimodal_embeddings) raise ValueError( - f"Attempted to assign {expr} = {num_multimodal_embeddings} " + f"Attempted to assign {expr} = {flattened.shape[0]} " f"multimodal tokens to {num_expected_tokens} placeholders") - inputs_embeds[mask] = flattened.view(num_expected_tokens, embed_dim) + inputs_embeds[mask] = flattened return inputs_embeds diff --git a/vllm/multimodal/base.py b/vllm/multimodal/base.py index c02e61596927a..17ef9938d0572 100644 --- a/vllm/multimodal/base.py +++ b/vllm/multimodal/base.py @@ -54,8 +54,8 @@ def _try_stack(nested_tensors: NestedTensors) -> NestedTensors: return nested_tensors stacked = [MultiModalInputs._try_stack(t) for t in nested_tensors] - if is_list_of(stacked, list): - # Do not stack nested lists + if not is_list_of(stacked, torch.Tensor, check="all"): + # Only tensors (not lists) can be stacked. return stacked tensors_ = cast(List[torch.Tensor], stacked)