From e255d19ee3d7aa93c6e4dbc67f226026874c559d Mon Sep 17 00:00:00 2001
From: Santiago Castro
Date: Tue, 1 Feb 2022 16:31:10 -0300
Subject: [PATCH] Add a test for NormalizationLogitsWarper

---
 .../generation/test_generation_logits_process.py | 16 ++++++++++++++++
 1 file changed, 16 insertions(+)

diff --git a/tests/generation/test_generation_logits_process.py b/tests/generation/test_generation_logits_process.py
index b95110d0e06b15..af62d85d23c96a 100644
--- a/tests/generation/test_generation_logits_process.py
+++ b/tests/generation/test_generation_logits_process.py
@@ -37,6 +37,7 @@
     MinLengthLogitsProcessor,
     NoBadWordsLogitsProcessor,
     NoRepeatNGramLogitsProcessor,
+    NormalizationLogitsWarper,
     PrefixConstrainedLogitsProcessor,
     RepetitionPenaltyLogitsProcessor,
     TemperatureLogitsWarper,
@@ -537,3 +538,18 @@ def test_exponential_decay_length_penalty(self):
                 scores_after_start[penalty_start + 1 :, eos_token_id], scores[penalty_start + 1 :, eos_token_id]
             ).all()
         )
+
+    def test_normalization_warper(self):
+        input_ids = None
+
+        scores = torch.tensor(
+            [[-23.18, -29.96, -43.54, 47.77], [-33.58, -26.87, -32.96, 22.51]], device=torch_device, dtype=torch.float
+        )
+
+        normalization_warper = NormalizationLogitsWarper()
+        normalized_scores = normalization_warper(input_ids, scores).exp()
+
+        ones = torch.ones(scores.shape[0], device=torch_device, dtype=torch.float)
+        self.assertTrue(normalized_scores.sum(dim=-1).allclose(ones))
+
+        self.assertTrue(normalized_scores.allclose(scores.softmax(dim=-1)))
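
Note on the class under test: this patch adds only the test; the NormalizationLogitsWarper
implementation itself is not part of the diff. The assertions pin down its contract, though:
exponentiating its output must yield rows that sum to one and match scores.softmax(dim=-1),
i.e. the warper must return a log-softmax over the vocabulary dimension. Below is a minimal
sketch consistent with those assertions, assuming the LogitsWarper base class from
transformers.generation_logits_process; the class body is illustrative, not the PR's actual
implementation:

    import torch

    from transformers.generation_logits_process import LogitsWarper


    class NormalizationLogitsWarper(LogitsWarper):
        """Renormalizes scores into log-probabilities over the vocabulary (sketch)."""

        def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
            # input_ids is unused (the test passes None); only the scores are warped.
            # log_softmax keeps the result in log space; the test's .exp() then
            # recovers a proper probability distribution summing to 1 per row.
            return scores.log_softmax(dim=-1)

Returning log-probabilities rather than probabilities is what lets the test compare
normalized_scores (after .exp()) directly against scores.softmax(dim=-1).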