@@ -53,7 +53,6 @@ def group_broadcast(t, shape):
 
 
 # Quantize assuming once scale per group of elements with shape group_shape,
-# Defaults to quantizing to correct platform specific float8 type
 # example group shapes:
 # * (-1, -1) for per-tensor quantization
 # * (1, -1) for per-row quantization
@@ -64,14 +63,14 @@ def group_broadcast(t, shape):
 def scaled_quantize(
     x: torch.Tensor,
     group_shape: Tuple[int, int],
-    dtype: torch.dtype,
+    tgt_dtype: torch.dtype,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     group_shape = _normalize_quant_group_shape(x, group_shape)
-    assert dtype.is_floating_point, \
+    assert tgt_dtype.is_floating_point, \
         "currently `scaled_quantize` only supports floating point dtypes " \
         "but could be extended to support other dtypes"
 
-    finfo = torch.finfo(dtype)
+    finfo = torch.finfo(tgt_dtype)
 
     # Reshape (M, N) into (BLK_M, BLOCK_SIZE_M, BLK_N, BLOCK_SIZE_N)
     assert x.ndim == 2
@@ -97,7 +96,7 @@ def scaled_quantize(
         .permute(0, 2, 1, 3)\
         .reshape(x.shape)
 
-    return x_scl_sat.to(dtype).contiguous(), scale.float().reciprocal()
+    return x_scl_sat.to(tgt_dtype).contiguous(), scale.float().reciprocal()
 
 
 # inverses `scaled_quantize`
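For context, a minimal usage sketch of the renamed signature (not part of the commit; the import path, the tensor sizes, and the choice of float8 dtype are assumptions): per-row quantization of a 2D tensor to float8, using the group-shape convention described in the comment block above.

    import torch
    # assumed import; `scaled_quantize` is the helper modified in this diff
    # from quant_utils import scaled_quantize

    x = torch.randn(64, 128, dtype=torch.float32)

    # group_shape=(1, -1): one scale per row (per-row quantization);
    # tgt_dtype selects the target floating point format.
    x_q, scale = scaled_quantize(x, group_shape=(1, -1),
                                 tgt_dtype=torch.float8_e4m3fn)

    # x_q has x.shape in float8_e4m3fn; scale (the reciprocal of the internal
    # scaling factor) is presumably what the dequantize helper below consumes.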