Skip to content

Commit

Permalink
try forcing float64
Browse files Browse the repository at this point in the history
  • Loading branch information
lostella committed Oct 17, 2024
1 parent 581121a commit 9694e9c
Showing 1 changed file with 1 addition and 4 deletions.
5 changes: 1 addition & 4 deletions test/torch/distribution/test_negative_binomial.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,17 +47,14 @@ def test_custom_neg_bin_cdf(total_count, probs, value):
assert np.allclose(torch_cdf, scipy_cdf)


@pytest.mark.skipif(
sys.version_info.major == 3 and sys.version_info.minor == 9, reason="test fails on python 3.9"
)
@pytest.mark.parametrize("probs", [0.1, 0.5, 0.8])
@pytest.mark.parametrize("total_count", [3, 7, 100])
@pytest.mark.parametrize("value", [0.1, 0.5, 0.9])
def test_custom_neg_bin_icdf(total_count, probs, value):
torch_dist = NegativeBinomial(total_count=total_count, probs=probs)
scipy_dist = torch_dist.scipy_nbinom

torch_icdf = torch_dist.icdf(torch.as_tensor(value)).numpy()
torch_icdf = torch_dist.icdf(torch.as_tensor(value, dtype=torch.float64)).numpy()
scipy_icdf = scipy_dist.ppf(np.asarray(value))

assert np.allclose(torch_icdf, scipy_icdf)

0 comments on commit 9694e9c

Please # to comment.