Skip to content

Commit

Permalink
Add a test
Browse files Browse the repository at this point in the history
  • Loading branch information
bryant1410 committed Mar 28, 2022
1 parent a3ca8aa commit e255d19
Showing 1 changed file with 16 additions and 0 deletions.
16 changes: 16 additions & 0 deletions tests/generation/test_generation_logits_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
MinLengthLogitsProcessor,
NoBadWordsLogitsProcessor,
NoRepeatNGramLogitsProcessor,
NormalizationLogitsWarper,
PrefixConstrainedLogitsProcessor,
RepetitionPenaltyLogitsProcessor,
TemperatureLogitsWarper,
Expand Down Expand Up @@ -537,3 +538,18 @@ def test_exponential_decay_length_penalty(self):
scores_after_start[penalty_start + 1 :, eos_token_id], scores[penalty_start + 1 :, eos_token_id]
).all()
)

def test_normalization_warper(self):
input_ids = None

scores = torch.tensor([[-23.18, -29.96, -43.54, 47.77], [-33.58, -26.87, -32.96, 22.51]], device=torch_device,
dtype=torch.float)

normalization_warper = NormalizationLogitsWarper()
normalized_scores = normalization_warper(input_ids, scores).exp()

ones = torch.ones(scores.shape[0], device=torch_device, dtype=torch.float)
self.assertTrue(normalized_scores.sum(dim=-1).allclose(ones))

self.assertTrue(normalized_scores.allclose(scores.softmax(dim=-1)))

0 comments on commit e255d19

Please # to comment.