Skip to content

Commit

Permalink
update decoder_vocab_size when resizing embeds (huggingface#16700)
Browse files — browse the repository at this point in the history
  • Loading branch information
patil-suraj authored and elusenji committed Jun 12, 2022
1 parent b7c6e31 commit 7a771da
Showing 1 changed file with 8 additions and 6 deletions.
14 changes: 8 additions & 6 deletions src/transformers/models/marian/modeling_marian.py
Original file line number Diff line number Diff line change
Expand Up @@ -1280,11 +1280,9 @@ def __init__(self, config: MarianConfig):
super().__init__(config)
self.model = MarianModel(config)

self.target_vocab_size = (
config.vocab_size if config.share_encoder_decoder_embeddings else config.decoder_vocab_size
)
self.register_buffer("final_logits_bias", torch.zeros((1, self.target_vocab_size)))
self.lm_head = nn.Linear(config.d_model, self.target_vocab_size, bias=False)
target_vocab_size = config.vocab_size if config.share_encoder_decoder_embeddings else config.decoder_vocab_size
self.register_buffer("final_logits_bias", torch.zeros((1, target_vocab_size)))
self.lm_head = nn.Linear(config.d_model, target_vocab_size, bias=False)

# Initialize weights and apply final processing
self.post_init()
Expand All @@ -1306,6 +1304,10 @@ def _resize_token_embeddings(self, new_num_tokens: int) -> nn.Embedding:
new_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens)
self.set_input_embeddings(new_embeddings)

# update config.decoder_vocab_size if embeddings are tied
if self.config.share_encoder_decoder_embeddings:
self.config.decoder_vocab_size = new_num_tokens

# if word embeddings are not tied, make sure that lm head is resized as well
if (
self.config.share_encoder_decoder_embeddings
Expand Down Expand Up @@ -1451,7 +1453,7 @@ def forward(
masked_lm_loss = None
if labels is not None:
loss_fct = CrossEntropyLoss()
masked_lm_loss = loss_fct(lm_logits.view(-1, self.target_vocab_size), labels.view(-1))
masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.decoder_vocab_size), labels.view(-1))

if not return_dict:
output = (lm_logits,) + outputs[1:]
Expand Down

0 comments on commit 7a771da

Please sign in to comment.