Skip to content

Commit

Permalink
fix: update bmm fp8 test (#487)
Browse files Browse the repository at this point in the history
fp8 scale per tensor

ref sgl-project/sglang#1285
  • Loading branch information
zhyncs authored Sep 1, 2024
1 parent 77bff3f commit 45eac04
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 6 deletions.
5 changes: 3 additions & 2 deletions python/flashinfer/gemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,8 +241,9 @@ def bmm_fp8(
>>> import flashinfer
>>> def to_float8(x, dtype=torch.float8_e4m3fn):
... finfo = torch.finfo(dtype)
... abs_max = x.abs().amax(dim=(1, 2), keepdim=True).clamp(min=1e-12)
... scale = finfo.max / abs_max
... min_val, max_val = x.aminmax()
... amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12)
... scale = finfo.max / amax
... x_scl_sat = (x * scale).clamp(min=finfo.min, max=finfo.max)
... return x_scl_sat.to(dtype), scale.float().reciprocal()
>>>
Expand Down
9 changes: 5 additions & 4 deletions python/tests/test_bmm_fp8.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
import pytest
import torch
import torch.nn.functional as F

from flashinfer import bmm_fp8


def to_float8(x, dtype=torch.float8_e4m3fn):
finfo = torch.finfo(dtype)
abs_max = x.abs().amax(dim=(1, 2), keepdim=True).clamp(min=1e-12)
scale = finfo.max / abs_max
min_val, max_val = x.aminmax()
amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12)
scale = finfo.max / amax
x_scl_sat = (x * scale).clamp(min=finfo.min, max=finfo.max)
return x_scl_sat.to(dtype), scale.float().reciprocal()

Expand All @@ -32,9 +34,8 @@ def test_bmm_fp8(input_dtype, mat2_dtype, res_dtype):
bmm_fp8(input_fp8, mat2_fp8, input_inv_s, mat2_inv_s, res_dtype, res)

reference = torch.bmm(input, mat2)

cos_sim = F.cosine_similarity(reference.reshape(-1), res.reshape(-1), dim=0)
assert cos_sim > 0.98
assert cos_sim > 0.99


if __name__ == "__main__":
Expand Down

0 comments on commit 45eac04

Please # to comment.