Skip to content

Commit

Permalink
Fix style
Browse files Browse the repository at this point in the history
  • Loading branch information
bryant1410 committed Mar 28, 2022
1 parent e255d19 commit b9bd8a1
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions tests/generation/test_generation_logits_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -542,8 +542,9 @@ def test_exponential_decay_length_penalty(self):
def test_normalization_warper(self):
input_ids = None

scores = torch.tensor([[-23.18, -29.96, -43.54, 47.77], [-33.58, -26.87, -32.96, 22.51]], device=torch_device,
dtype=torch.float)
scores = torch.tensor(
[[-23.18, -29.96, -43.54, 47.77], [-33.58, -26.87, -32.96, 22.51]], device=torch_device, dtype=torch.float
)

normalization_warper = NormalizationLogitsWarper()
normalized_scores = normalization_warper(input_ids, scores).exp()
Expand Down

0 comments on commit b9bd8a1

Please # to comment.