diff --git a/tests/generation/test_generation_logits_process.py b/tests/generation/test_generation_logits_process.py index af62d85d23c96a..d59ed49149ddde 100644 --- a/tests/generation/test_generation_logits_process.py +++ b/tests/generation/test_generation_logits_process.py @@ -542,8 +542,9 @@ def test_exponential_decay_length_penalty(self): def test_normalization_warper(self): input_ids = None - scores = torch.tensor([[-23.18, -29.96, -43.54, 47.77], [-33.58, -26.87, -32.96, 22.51]], device=torch_device, - dtype=torch.float) + scores = torch.tensor( + [[-23.18, -29.96, -43.54, 47.77], [-33.58, -26.87, -32.96, 22.51]], device=torch_device, dtype=torch.float + ) normalization_warper = NormalizationLogitsWarper() normalized_scores = normalization_warper(input_ids, scores).exp()