
Commit 20b2d90

alex-jw-brooks authored and kerthcet committed
[Bugfix] Fix Positive Feature Layers in Llava Models (vllm-project#13514)
Signed-off-by: Alex-Brooks <Alex.brooks@ibm.com>
1 parent 3f5fe64 commit 20b2d90

File tree

6 files changed, +44 -9 lines changed


tests/models/test_vision.py (+34)

@@ -0,0 +1,34 @@
+# SPDX-License-Identifier: Apache-2.0
+
+import pytest
+import torch
+
+from vllm.model_executor.models.vision import resolve_visual_encoder_outputs
+
+
+@pytest.mark.parametrize(
+    ("feature_sample_layers", "num_layers_loaded", "max_possible_layers",
+     "expected_features"),
+    [
+        # All layers loaded
+        ([1, 10], 10, 10, [1, 10]),
+        ([-10, -1], 10, 10, [1, 10]),
+        # Some layers not loaded
+        ([1, 10], 10, 20, [1, 10]),
+        ([-20, -11], 10, 20, [1, 10]),
+    ])
+def test_resolve_visual_encoder_outputs(feature_sample_layers,
+                                        num_layers_loaded, max_possible_layers,
+                                        expected_features):
+    """
+    Test that offsets are correctly handled for vision feature layers.
+    """
+    encoder_outputs = [
+        torch.tensor([idx]) for idx in range(num_layers_loaded + 1)
+    ]
+    output_tensor = resolve_visual_encoder_outputs(
+        encoder_outputs=encoder_outputs,
+        feature_sample_layers=feature_sample_layers,
+        post_layer_norm=None,
+        max_possible_layers=max_possible_layers)
+    assert torch.equal(torch.tensor(expected_features), output_tensor)

vllm/model_executor/models/clip.py (+1 -1)

@@ -251,7 +251,7 @@ def __init__(
     def forward(
         self, inputs_embeds: torch.Tensor, return_all_hidden_states: bool
     ) -> Union[torch.Tensor, list[torch.Tensor]]:
-        hidden_states_pool = []
+        hidden_states_pool = [inputs_embeds]
         hidden_states = inputs_embeds
 
         for encoder_layer in self.layers:

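The clip.py change (and the identical ones to siglip.py and pixtral.py below) seeds the hidden-state pool with the encoder inputs, so that when all hidden states are returned, index 0 is the pre-encoder embedding output and index i is the output of layer i, matching the Hugging Face hidden_states convention. A minimal sketch of the pattern (ToyEncoder is a hypothetical stand-in, not a vLLM class):

import torch
from torch import nn


class ToyEncoder(nn.Module):
    """Hypothetical stand-in for the CLIP/SigLIP/Pixtral encoders."""

    def __init__(self, num_layers: int, dim: int):
        super().__init__()
        self.layers = nn.ModuleList(
            nn.Linear(dim, dim) for _ in range(num_layers))

    def forward(self, inputs_embeds: torch.Tensor,
                return_all_hidden_states: bool):
        # Seed the pool with the inputs so that index 0 is the embedding
        # output and index i is the output of encoder layer i.
        hidden_states_pool = [inputs_embeds]
        hidden_states = inputs_embeds
        for encoder_layer in self.layers:
            hidden_states = encoder_layer(hidden_states)
            if return_all_hidden_states:
                hidden_states_pool.append(hidden_states)
        if return_all_hidden_states:
            return hidden_states_pool
        return hidden_states
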
vllm/model_executor/models/llava.py (+2 -2)

@@ -428,7 +428,7 @@ def _get_num_hidden_layers(hf_config: LlavaLikeConfig) -> int:
 
 
 def _get_layer_index(feature_layer_index: int, num_hidden_layers: int) -> int:
-    """Given an signed vision feature layer, get the number of hidden layers
+    """Given a signed vision feature layer, get the number of hidden layers
     needed to leverage it.
 
     Args:
@@ -438,7 +438,7 @@ def _get_layer_index(feature_layer_index: int, num_hidden_layers: int) -> int:
     """
     if feature_layer_index < 0:
         return num_hidden_layers + feature_layer_index + 1
-    return feature_layer_index + 1
+    return feature_layer_index
 
 
 def init_vision_tower_for_llava(

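To see what the llava.py change does, here is a standalone copy of the patched _get_layer_index logic with a worked example; the 24-layer tower is an illustrative assumption, not something from this commit:

def _get_layer_index(feature_layer_index: int, num_hidden_layers: int) -> int:
    # Copy of the patched logic, for illustration only.
    if feature_layer_index < 0:
        return num_hidden_layers + feature_layer_index + 1
    return feature_layer_index


# With a hypothetical 24-layer vision tower:
assert _get_layer_index(-1, 24) == 24  # last layer -> load all 24 layers
assert _get_layer_index(-2, 24) == 23  # second-to-last -> load 23 layers
# Positive indices already count layers in the HF hidden_states convention
# (index 0 is the embedding output), so they now map through unchanged
# instead of being bumped by one.
assert _get_layer_index(23, 24) == 23
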
vllm/model_executor/models/pixtral.py (+1 -1)

@@ -969,7 +969,7 @@ def forward(
         position_embeddings: torch.Tensor,
         return_all_hidden_states: bool,
     ) -> torch.Tensor:
-        hidden_states_pool = []
+        hidden_states_pool = [x]
 
         for layer in self.layers:
            x = layer(x, attention_mask, position_embeddings)

vllm/model_executor/models/siglip.py (+1 -1)

@@ -378,7 +378,7 @@ def forward(
         inputs_embeds: torch.Tensor,
         return_all_hidden_states: bool,
     ) -> Union[torch.Tensor, list[torch.Tensor]]:
-        hidden_states_pool = []
+        hidden_states_pool = [inputs_embeds]
         hidden_states = inputs_embeds
 
         for encoder_layer in self.layers:

vllm/model_executor/models/vision.py (+5 -4)

@@ -132,10 +132,11 @@ def resolve_visual_encoder_outputs(
     # Get the hidden states corresponding to the layer indices.
     # Negative values are relative to the full visual encoder,
     # so offset them depending on how many layers were loaded.
-    # NOTE: this assumes that encoder_outputs contains a list
-    # of hidden states in the same order as the encoder layers
-    # that produced them.
-    offset = max_possible_layers - len(encoder_outputs)
+    # NOTE: this assumes that encoder_outputs is a list containing
+    # the inputs to the visual encoder, followed by the hidden states
+    # of each layer.
+    num_loaded_layers = len(encoder_outputs) - 1
+    offset = max_possible_layers - num_loaded_layers
     hs_pool = [
         encoder_outputs[layer_idx]
         if layer_idx >= 0 else encoder_outputs[layer_idx + offset]

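The new offset arithmetic in vision.py can be traced by hand with one of the parametrized cases from the new test, assuming the pool layout established above (encoder_outputs[0] is the encoder input, encoder_outputs[i] is the output of layer i):

# Worked example mirroring the ([-20, -11], 10, 20, [1, 10]) test case,
# with plain integers standing in for the hidden-state tensors.
encoder_outputs = list(range(11))      # encoder input + 10 loaded layers
max_possible_layers = 20               # depth of the full (unloaded) encoder
feature_sample_layers = [-20, -11]     # indices relative to the full encoder

num_loaded_layers = len(encoder_outputs) - 1       # 10
offset = max_possible_layers - num_loaded_layers   # 10

hs_pool = [
    encoder_outputs[layer_idx]
    if layer_idx >= 0 else encoder_outputs[layer_idx + offset]
    for layer_idx in feature_sample_layers
]
assert hs_pool == [1, 10]  # same features as the expected test output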