diff --git a/vllm/model_executor/models/granite.py b/vllm/model_executor/models/granite.py
index eae275b5ad954..07a71109f4895 100644
--- a/vllm/model_executor/models/granite.py
+++ b/vllm/model_executor/models/granite.py
@@ -294,8 +294,7 @@ def __init__(
             self.norm = PPMissingLayer()
 
     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
-        inputs_embeds = self.embed_tokens(input_ids)
-        return inputs_embeds
+        return self.embed_tokens(input_ids)
 
     def forward(
         self,
@@ -317,7 +316,7 @@ def forward(
             hidden_states = intermediate_tensors["hidden_states"]
             residual = intermediate_tensors["residual"]
 
-        hidden_states = hidden_states * self.config.embedding_multiplier
+        hidden_states *= self.config.embedding_multiplier
 
         for i in range(self.start_layer, self.end_layer):
             layer = self.layers[i]
@@ -426,10 +425,12 @@ def forward(
         return model_output
 
     def compute_logits(self, hidden_states: torch.Tensor,
-                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
+                       sampling_metadata: SamplingMetadata) -> Optional[torch.Tensor]:
         logits = self.logits_processor(self.lm_head, hidden_states,
                                        sampling_metadata)
-        logits = logits / self.config.logits_scaling
+        # The Optional return type means logits may be None; scale only a real tensor.
+        if logits is not None:
+            logits /= self.config.logits_scaling
         return logits
 
     def sample(