diff --git a/tests/brevitas/core/test_float_quant.py b/tests/brevitas/core/test_float_quant.py index b4f2cc89e..a5f597586 100644 --- a/tests/brevitas/core/test_float_quant.py +++ b/tests/brevitas/core/test_float_quant.py @@ -109,9 +109,10 @@ def test_float_to_quant_float(inp, minifloat_format): @given(inp=float_tensor_random_shape_st(), minifloat_format=random_minifloat_format()) @jit_disabled_for_mock() def test_scaling_impls_called_once(inp, minifloat_format): + float_scaling_impl_return = 1. bit_width, exponent_bit_width, mantissa_bit_width, signed, exponent_bias = minifloat_format scaling_impl = mock.Mock(side_effect=lambda x, y: 1.) - float_scaling_impl = mock.Mock(side_effect=lambda x, y, z: 1.) + float_scaling_impl = mock.Mock(side_effect=lambda x, y, z: float_scaling_impl_return) if exponent_bit_width == 0 or mantissa_bit_width == 0: with pytest.raises(RuntimeError): float_quant = FloatQuant( @@ -148,7 +149,7 @@ def test_scaling_impls_called_once(inp, minifloat_format): torch.tensor(exponent_bit_width), torch.tensor(mantissa_bit_width), torch.tensor(exponent_bias)) - scaling_impl.assert_called_once_with(inp) + scaling_impl.assert_called_once_with(inp, float_scaling_impl_return) @given(