diff --git a/tests/brevitas/graph/test_calibration.py b/tests/brevitas/graph/test_calibration.py
index b22994275..52819c22f 100644
--- a/tests/brevitas/graph/test_calibration.py
+++ b/tests/brevitas/graph/test_calibration.py
@@ -12,6 +12,7 @@
 from brevitas.graph.calibrate import bias_correction_mode
 from brevitas.graph.calibrate import calibration_mode
 from brevitas.graph.calibrate import load_quant_model_mode
+from brevitas.inject.enum import RestrictValueType
 import brevitas.nn as qnn
 from brevitas.quant import Int8ActPerTensorFixedPoint
 from brevitas.quant.experimental.float import Fp8e4m3ActPerTensorFloat
@@ -27,7 +28,9 @@
 BATCH = 1
 REFERENCE_SCALES = {
     'int_quant': (0.00935234408825635910, 0.01362917013466358185),
-    'fp_quant': (0.00249395845457911491, 0.00363444536924362183)}
+    'fp_quant': (0.00249395845457911491, 0.00363444536924362183),
+    'int_po2_quant': (0.015625, 0.015625),
+    'fp_po2_quant': (0.001953125, 0.00390625),}
 REFERENCE_INP = torch.tensor([[-1.8645, -0.4071, 1.1971]])
 REFERENCE_WEIGHTS = torch.tensor([[1.0023, 0.0205, 1.4604], [-0.2918, -1.8218, -0.7010],
                                   [1.4573, -0.9074, -0.2708]])
@@ -75,7 +78,15 @@ def forward(self, x):
     assert torch.allclose(expected_scale, scale)
 
 
-QUANTS = {'int_quant': Int8ActPerTensorFloat, 'fp_quant': Fp8e4m3ActPerTensorFloat}
+class Fp8e4m3ActPerTensorFixedPoint(Fp8e4m3ActPerTensorFloat):
+    restrict_scaling_type = RestrictValueType.POWER_OF_TWO
+
+
+QUANTS = {
+    'int_quant': Int8ActPerTensorFloat,
+    'fp_quant': Fp8e4m3ActPerTensorFloat,
+    'int_po2_quant': Int8ActPerTensorFixedPoint,
+    'fp_po2_quant': Fp8e4m3ActPerTensorFixedPoint}
 
 
 @pytest_cases.parametrize("act_quant", QUANTS.items(), ids=QUANTS.keys())
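Not part of the patch above: a minimal sketch of what the new *_po2_quant cases exercise, namely calibrating an activation quantizer whose scale factor is restricted to powers of two. The Fp8e4m3ActPerTensorFixedPoint subclass mirrors the one added in the diff; the ActModel helper, the input shape, and the quant_act_scale() check are illustrative assumptions rather than code from this PR.

# Illustrative sketch only, not part of the patch.
import torch
import torch.nn as nn

import brevitas.nn as qnn
from brevitas.graph.calibrate import calibration_mode
from brevitas.inject.enum import RestrictValueType
from brevitas.quant.experimental.float import Fp8e4m3ActPerTensorFloat


class Fp8e4m3ActPerTensorFixedPoint(Fp8e4m3ActPerTensorFloat):
    # Same override as in the diff: restrict the learned scale to powers of two.
    restrict_scaling_type = RestrictValueType.POWER_OF_TWO


class ActModel(nn.Module):  # hypothetical helper, not from the PR

    def __init__(self):
        super().__init__()
        self.act = qnn.QuantReLU(act_quant=Fp8e4m3ActPerTensorFixedPoint)

    def forward(self, x):
        return self.act(x)


model = ActModel().eval()
inp = torch.randn(1, 3)

# Collect activation statistics; the quantizer derives its scale from them.
with torch.no_grad(), calibration_mode(model):
    model(inp)

# With RestrictValueType.POWER_OF_TWO the calibrated scale should be an exact
# 2**k, which is what the new 'fp_po2_quant' reference values in the diff encode.
scale = model.act.quant_act_scale()
assert torch.allclose(torch.exp2(torch.round(torch.log2(scale))), scale)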