Skip to content

Commit

Permalink
Add mup_embedding_multiplier
Browse files Browse the repository at this point in the history
  • Loading branch information
adk9 committed Jun 11, 2024
1 parent 0ce02a9 commit 230615b
Showing 1 changed file with 8 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,11 @@ def positional_embedding_type(self) -> PositionalEmbeddingType:
def positional_embedding_config(self) -> Optional[RotateHalfConfig]:
return RotateHalfConfig(theta_base=self._config.rope_embedding_base)

@property
def mup_embedding_multiplier(self) -> float:
    """muP embedding multiplier: scalar applied to the token-embedding output.

    A positive value enables the scaling in ``_forward_embed`` (the embedding
    tensor is multiplied by this factor); the value is fixed for this model.
    """
    multiplier = 10.0
    return multiplier


"""
Forward implementations
"""
Expand All @@ -127,6 +132,9 @@ def _forward_embed(self, ragged_batch: RaggedBatchWrapper) -> torch.Tensor:
if embed.shape[-1] != self.model_dim:
raise ValueError(f"Embedding output shape {embed.shape} does not match model_dim {self.model_dim}")

if self.mup_embedding_multiplier > 0.0:
embed = embed * self.mup_embedding_multiplier

return embed

def _forward_transformer_layer(self, layer_idx: int, residual: torch.Tensor, hidden_states: torch.Tensor,
Expand Down

0 comments on commit 230615b

Please sign in to comment.