diff --git a/tests/models/test_granite.py b/tests/models/test_granite.py
index 2435b5dc3ff88..c1841175b40f5 100644
--- a/tests/models/test_granite.py
+++ b/tests/models/test_granite.py
@@ -7,7 +7,6 @@
 import pytest
 
 from .utils import check_logprobs_close
-
 TRANSFORMERS_VERSION = tuple(
     map(int, importlib.metadata.version("transformers").split(".")))
 
diff --git a/vllm/model_executor/models/granite.py b/vllm/model_executor/models/granite.py
index d00e07dde7d1c..7c31733262331 100644
--- a/vllm/model_executor/models/granite.py
+++ b/vllm/model_executor/models/granite.py
@@ -235,7 +235,6 @@ def forward(
         hidden_states: torch.Tensor,
         kv_cache: torch.Tensor,
         attn_metadata: AttentionMetadata,
-        residual: Optional[torch.Tensor],
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
+    ) -> torch.Tensor:
         # Self Attention
         residual = hidden_states
@@ -252,7 +251,7 @@
         hidden_states = self.post_attention_layernorm(hidden_states)
         hidden_states = self.mlp(hidden_states)
         hidden_states = residual + hidden_states * self.residual_multiplier
-        return hidden_states, residual
+        return hidden_states
 
 
 class GraniteModel(nn.Module):
@@ -321,12 +320,11 @@
 
         for i in range(self.start_layer, self.end_layer):
             layer = self.layers[i]
-            hidden_states, residual = layer(
+            hidden_states = layer(
                 positions,
                 hidden_states,
                 kv_caches[i - self.start_layer],
                 attn_metadata,
-                residual,
             )
 
         if not get_pp_group().is_last_rank:
@@ -335,7 +333,6 @@
-                "residual": residual
             })
 
-        hidden_states, _ = self.norm(hidden_states, residual)
+        hidden_states = self.norm(hidden_states)
         return hidden_states
 
 