Skip to content

Commit

Permalink
Fix float tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Giuseppe5 committed Oct 1, 2024
1 parent c6d60dc commit 19cd637
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions tests/brevitas/core/test_float_quant.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,9 +109,10 @@ def test_float_to_quant_float(inp, minifloat_format):
@given(inp=float_tensor_random_shape_st(), minifloat_format=random_minifloat_format())
@jit_disabled_for_mock()
def test_scaling_impls_called_once(inp, minifloat_format):
float_scaling_impl_return = 1.
bit_width, exponent_bit_width, mantissa_bit_width, signed, exponent_bias = minifloat_format
scaling_impl = mock.Mock(side_effect=lambda x, y: 1.)
float_scaling_impl = mock.Mock(side_effect=lambda x, y, z: 1.)
float_scaling_impl = mock.Mock(side_effect=lambda x, y, z: float_scaling_impl_return)
if exponent_bit_width == 0 or mantissa_bit_width == 0:
with pytest.raises(RuntimeError):
float_quant = FloatQuant(
Expand Down Expand Up @@ -148,7 +149,7 @@ def test_scaling_impls_called_once(inp, minifloat_format):
torch.tensor(exponent_bit_width),
torch.tensor(mantissa_bit_width),
torch.tensor(exponent_bias))
scaling_impl.assert_called_once_with(inp)
scaling_impl.assert_called_once_with(inp, float_scaling_impl_return)


@given(
Expand Down

0 comments on commit 19cd637

Please # to comment.