
Commit

Undo the changes in the beam search functions
bryant1410 committed Feb 1, 2022
1 parent 393c527 commit 066b52c
Showing 1 changed file with 15 additions and 22 deletions.
37 changes: 15 additions & 22 deletions src/transformers/generation_utils.py
@@ -1989,14 +1989,12 @@ def beam_search(
             # hack: adjust tokens for Marian. For Marian we have to make sure that the `pad_token_id`
             # cannot be generated both before and after the `nn.functional.log_softmax` operation.
             next_token_logits = self.adjust_logits_during_generation(next_token_logits, cur_len=cur_len)
-
-            next_token_scores_processed = logits_processor(input_ids, next_token_logits)
-
-            next_token_scores_processed = nn.functional.log_softmax(
-                next_token_scores_processed, dim=-1
+            next_token_scores = nn.functional.log_softmax(
+                next_token_logits, dim=-1
             )  # (batch_size * num_beams, vocab_size)
 
-            next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as(next_token_scores_processed)
+            next_token_scores_processed = logits_processor(input_ids, next_token_scores)
+            next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as(next_token_scores)
 
             # Store scores, attentions and hidden_states when required
             if return_dict_in_generate:
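
For context, a minimal standalone sketch of the ordering restored above (toy tensors only, not the actual transformers internals): scores are normalized with log_softmax first, a logits processor then operates on those log-probabilities, and only afterwards are the running beam scores added.

import torch
from torch import nn

# Toy shapes: 2 beams over a 5-token vocabulary (illustrative values only).
next_token_logits = torch.randn(2, 5)    # (batch_size * num_beams, vocab_size)
beam_scores = torch.tensor([0.0, -1.2])  # running log-probability of each beam

# Restored order: normalize first so the processors see log-probabilities ...
next_token_scores = nn.functional.log_softmax(next_token_logits, dim=-1)

# ... then apply a processor. Hypothetical stand-in for `logits_processor`:
# force token 3 by masking every other token with -inf.
next_token_scores_processed = next_token_scores.clone()
next_token_scores_processed[:, :3] = float("-inf")
next_token_scores_processed[:, 4:] = float("-inf")

# Finally add the per-beam running scores, broadcast across the vocabulary.
next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as(next_token_scores_processed)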
@@ -2307,14 +2305,12 @@ def beam_sample(
             # hack: adjust tokens for Marian. For Marian we have to make sure that the `pad_token_id`
             # cannot be generated both before and after the `nn.functional.log_softmax` operation.
             next_token_logits = self.adjust_logits_during_generation(next_token_logits, cur_len=cur_len)
-
-            next_token_scores_processed = logits_processor(input_ids, next_token_logits)
-
-            next_token_scores_processed = nn.functional.log_softmax(
-                next_token_scores_processed, dim=-1
+            next_token_scores = nn.functional.log_softmax(
+                next_token_logits, dim=-1
             )  # (batch_size * num_beams, vocab_size)
 
-            next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as(next_token_scores_processed)
+            next_token_scores_processed = logits_processor(input_ids, next_token_scores)
+            next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as(next_token_scores)
             next_token_scores = logits_warper(input_ids, next_token_scores)
 
             # Store scores, attentions and hidden_states when required
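
beam_sample follows the same restored ordering and then additionally passes the combined scores through `logits_warper` before sampling. A rough sketch, with a hypothetical top-2 warp standing in for the warper:

import torch
from torch import nn

# Pretend these already contain the log_softmax output, processor output,
# and the added beam_scores from the step shown in the previous sketch.
next_token_scores = torch.randn(2, 5)  # (batch_size * num_beams, vocab_size)

# Hypothetical stand-in for `logits_warper`: keep the top 2 tokens per beam
# and mask the rest with -inf.
top_vals, _ = next_token_scores.topk(2, dim=-1)
next_token_scores = next_token_scores.masked_fill(next_token_scores < top_vals[:, -1:], float("-inf"))

# beam_sample then samples candidate tokens from the warped distribution
# instead of taking the top scores deterministically (simplified here).
probs = nn.functional.softmax(next_token_scores, dim=-1)
next_tokens = torch.multinomial(probs, num_samples=2)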
@@ -2659,19 +2655,16 @@ def group_beam_search(
                 # hack: adjust tokens for Marian. For Marian we have to make sure that the `pad_token_id`
                 # cannot be generated both before and after the `nn.functional.log_softmax` operation.
                 next_token_logits = self.adjust_logits_during_generation(next_token_logits, cur_len=cur_len)
-
-                next_token_scores_processed = logits_processor(
-                    group_input_ids, next_token_logits, current_tokens=current_tokens, beam_group_idx=beam_group_idx
-                )
-
-                next_token_scores_processed = nn.functional.log_softmax(
-                    next_token_scores_processed, dim=-1
+                next_token_scores = nn.functional.log_softmax(
+                    next_token_logits, dim=-1
                 )  # (batch_size * group_size, vocab_size)
-                vocab_size = next_token_scores_processed.shape[-1]
+                vocab_size = next_token_scores.shape[-1]
 
-                next_token_scores = next_token_scores_processed + beam_scores[batch_group_indices].unsqueeze(-1).expand_as(
-                    next_token_scores_processed
+                next_token_scores_processed = logits_processor(
+                    group_input_ids, next_token_scores, current_tokens=current_tokens, beam_group_idx=beam_group_idx
                 )
+                next_token_scores = next_token_scores_processed + beam_scores[batch_group_indices].unsqueeze(-1)
+                next_token_scores = next_token_scores.expand_as(next_token_scores_processed)
 
                 if output_scores:
                     processed_score[batch_group_indices] = next_token_scores_processed
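
The group_beam_search hunk applies the same reordering to the diverse-beam-search path; the only structural difference is that beam scores are indexed per beam group and added with unsqueeze(-1), then expanded back to the processed tensor's shape. A small illustrative sketch (toy values, names mirroring the diff):

import torch

# Toy setup: 4 beams split into groups, 5-token vocabulary; the current group
# covers beams 0 and 1 (illustrative values only).
beam_scores = torch.tensor([0.0, -0.5, -1.0, -1.5])
batch_group_indices = torch.tensor([0, 1])
next_token_scores_processed = torch.randn(2, 5)  # processor output for this group

# Per-group running scores are added via unsqueeze(-1) (rather than [:, None]),
# then the result is expanded to match the processed scores, as in the diff.
next_token_scores = next_token_scores_processed + beam_scores[batch_group_indices].unsqueeze(-1)
next_token_scores = next_token_scores.expand_as(next_token_scores_processed)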
