diff --git a/python/flashinfer/gemm.py b/python/flashinfer/gemm.py index b66a5011..0bd2b373 100644 --- a/python/flashinfer/gemm.py +++ b/python/flashinfer/gemm.py @@ -241,8 +241,9 @@ def bmm_fp8( >>> import flashinfer >>> def to_float8(x, dtype=torch.float8_e4m3fn): ... finfo = torch.finfo(dtype) - ... abs_max = x.abs().amax(dim=(1, 2), keepdim=True).clamp(min=1e-12) - ... scale = finfo.max / abs_max + ... min_val, max_val = x.aminmax() + ... amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12) + ... scale = finfo.max / amax ... x_scl_sat = (x * scale).clamp(min=finfo.min, max=finfo.max) ... return x_scl_sat.to(dtype), scale.float().reciprocal() >>> diff --git a/python/tests/test_bmm_fp8.py b/python/tests/test_bmm_fp8.py index 54eb506c..80a89e12 100644 --- a/python/tests/test_bmm_fp8.py +++ b/python/tests/test_bmm_fp8.py @@ -1,13 +1,15 @@ import pytest import torch import torch.nn.functional as F + from flashinfer import bmm_fp8 def to_float8(x, dtype=torch.float8_e4m3fn): finfo = torch.finfo(dtype) - abs_max = x.abs().amax(dim=(1, 2), keepdim=True).clamp(min=1e-12) - scale = finfo.max / abs_max + min_val, max_val = x.aminmax() + amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12) + scale = finfo.max / amax x_scl_sat = (x * scale).clamp(min=finfo.min, max=finfo.max) return x_scl_sat.to(dtype), scale.float().reciprocal() @@ -32,9 +34,8 @@ def test_bmm_fp8(input_dtype, mat2_dtype, res_dtype): bmm_fp8(input_fp8, mat2_fp8, input_inv_s, mat2_inv_s, res_dtype, res) reference = torch.bmm(input, mat2) - cos_sim = F.cosine_similarity(reference.reshape(-1), res.reshape(-1), dim=0) - assert cos_sim > 0.98 + assert cos_sim > 0.99 if __name__ == "__main__":