diff --git a/src/transformers/generation_utils.py b/src/transformers/generation_utils.py
index d7a13a9255f8de..e822609ecc0cf1 100644
--- a/src/transformers/generation_utils.py
+++ b/src/transformers/generation_utils.py
@@ -1989,14 +1989,12 @@ def beam_search(
             # hack: adjust tokens for Marian. For Marian we have to make sure that the `pad_token_id`
             # cannot be generated both before and after the `nn.functional.log_softmax` operation.
             next_token_logits = self.adjust_logits_during_generation(next_token_logits, cur_len=cur_len)
-
-            next_token_scores_processed = logits_processor(input_ids, next_token_logits)
-
-            next_token_scores_processed = nn.functional.log_softmax(
-                next_token_scores_processed, dim=-1
+            next_token_scores = nn.functional.log_softmax(
+                next_token_logits, dim=-1
             )  # (batch_size * num_beams, vocab_size)

-            next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as(next_token_scores_processed)
+            next_token_scores_processed = logits_processor(input_ids, next_token_scores)
+            next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as(next_token_scores)

             # Store scores, attentions and hidden_states when required
             if return_dict_in_generate:
@@ -2307,14 +2305,12 @@ def beam_sample(
             # hack: adjust tokens for Marian. For Marian we have to make sure that the `pad_token_id`
             # cannot be generated both before and after the `nn.functional.log_softmax` operation.
             next_token_logits = self.adjust_logits_during_generation(next_token_logits, cur_len=cur_len)
-
-            next_token_scores_processed = logits_processor(input_ids, next_token_logits)
-
-            next_token_scores_processed = nn.functional.log_softmax(
-                next_token_scores_processed, dim=-1
+            next_token_scores = nn.functional.log_softmax(
+                next_token_logits, dim=-1
             )  # (batch_size * num_beams, vocab_size)

-            next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as(next_token_scores_processed)
+            next_token_scores_processed = logits_processor(input_ids, next_token_scores)
+            next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as(next_token_scores)
             next_token_scores = logits_warper(input_ids, next_token_scores)

             # Store scores, attentions and hidden_states when required
@@ -2659,19 +2655,16 @@ def group_beam_search(
                 # hack: adjust tokens for Marian. For Marian we have to make sure that the `pad_token_id`
                 # cannot be generated both before and after the `nn.functional.log_softmax` operation.
                 next_token_logits = self.adjust_logits_during_generation(next_token_logits, cur_len=cur_len)
-
-                next_token_scores_processed = logits_processor(
-                    group_input_ids, next_token_logits, current_tokens=current_tokens, beam_group_idx=beam_group_idx
-                )
-
-                next_token_scores_processed = nn.functional.log_softmax(
-                    next_token_scores_processed, dim=-1
+                next_token_scores = nn.functional.log_softmax(
+                    next_token_logits, dim=-1
                 )  # (batch_size * group_size, vocab_size)
-                vocab_size = next_token_scores_processed.shape[-1]
+                vocab_size = next_token_scores.shape[-1]

-                next_token_scores = next_token_scores_processed + beam_scores[batch_group_indices].unsqueeze(-1).expand_as(
-                    next_token_scores_processed
+                next_token_scores_processed = logits_processor(
+                    group_input_ids, next_token_scores, current_tokens=current_tokens, beam_group_idx=beam_group_idx
                 )
+                next_token_scores = next_token_scores_processed + beam_scores[batch_group_indices].unsqueeze(-1)
+                next_token_scores = next_token_scores.expand_as(next_token_scores_processed)

                 if output_scores:
                     processed_score[batch_group_indices] = next_token_scores_processed
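All three hunks make the same change: `nn.functional.log_softmax` now runs before the logits processors, so processors see (and return) log-probabilities, and the processed tensor that gets stored (e.g. `processed_score[batch_group_indices] = next_token_scores_processed` in `group_beam_search`) is a genuine processed log-prob rather than a post-hoc renormalization. A minimal, self-contained sketch of why the order matters, using toy tensors and a hypothetical `min_length_processor` as a stand-in for a real `LogitsProcessor` (this is not the library code itself):

import torch
import torch.nn.functional as F

num_beams, vocab_size = 2, 5
eos_token_id = 4  # hypothetical EOS id for this toy vocabulary
next_token_logits = torch.randn(num_beams, vocab_size)
beam_scores = torch.zeros(num_beams)

def min_length_processor(scores: torch.Tensor) -> torch.Tensor:
    # Toy stand-in for a min-length processor: forbid EOS at this step.
    scores = scores.clone()
    scores[:, eos_token_id] = float("-inf")
    return scores

# New order (this patch): normalize first, then process.
next_token_scores = F.log_softmax(next_token_logits, dim=-1)
next_token_scores_processed = min_length_processor(next_token_scores)
next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as(next_token_scores)

# Old order, for comparison: processing the raw logits and normalizing
# afterwards silently renormalizes over the surviving tokens, so the
# stored per-token scores disagree with the new ones.
old_scores = F.log_softmax(min_length_processor(next_token_logits), dim=-1)
assert not torch.allclose(next_token_scores_processed, old_scores)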