Skip to content

Commit

Permalink
Final norm fix.
Browse files (browse the repository at this point in the history)
  • Loading branch information
shawntan committed Aug 30, 2024
1 parent d849a27 commit a339b4d
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 6 deletions.
1 change: 0 additions & 1 deletion tests/models/test_granite.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import pytest

from .utils import check_logprobs_close

TRANSFORMERS_VERSION = tuple(
map(int,
importlib.metadata.version("transformers").split(".")))
Expand Down
8 changes: 3 additions & 5 deletions vllm/model_executor/models/granite.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,6 @@ def forward(
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
# Self Attention
residual = hidden_states
Expand All @@ -252,7 +251,7 @@ def forward(
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states * self.residual_multiplier
return hidden_states, residual
return hidden_states


class GraniteModel(nn.Module):
Expand Down Expand Up @@ -321,12 +320,11 @@ def forward(

for i in range(self.start_layer, self.end_layer):
layer = self.layers[i]
hidden_states, residual = layer(
hidden_states = layer(
positions,
hidden_states,
kv_caches[i - self.start_layer],
attn_metadata,
residual,
)

if not get_pp_group().is_last_rank:
Expand All @@ -335,7 +333,7 @@ def forward(
"residual": residual
})

hidden_states, _ = self.norm(hidden_states, residual)
hidden_states = self.norm(hidden_states)
return hidden_states


Expand Down

0 comments on commit a339b4d

Please sign in to comment.