
Commit c9d72cb

more cleanup
Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
1 parent: 3cdd2ce

1 file changed: +4 −5 lines

vllm/model_executor/layers/quantization/utils/quant_utils.py  (+4, −5)

@@ -53,7 +53,6 @@ def group_broadcast(t, shape):
 
 
 # Quantize assuming once scale per group of elements with shape group_shape,
-# Defaults to quantizing to correct platform specific float8 type
 # example group shapes:
 # * (-1, -1) for per-tensor quantization
 # * (1, -1) for per-row quantization
@@ -64,14 +63,14 @@ def group_broadcast(t, shape):
 def scaled_quantize(
     x: torch.Tensor,
     group_shape: Tuple[int, int],
-    dtype: torch.dtype,
+    tgt_dtype: torch.dtype,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     group_shape = _normalize_quant_group_shape(x, group_shape)
-    assert dtype.is_floating_point, \
+    assert tgt_dtype.is_floating_point, \
         "currently `scaled_quantize` only supports floating point dtypes " \
         "but could be extended to support other dtypes"
 
-    finfo = torch.finfo(dtype)
+    finfo = torch.finfo(tgt_dtype)
 
     # Reshape (M, N) into (BLK_M, BLOCK_SIZE_M, BLK_N, BLOCK_SIZE_N)
     assert x.ndim == 2
@@ -97,7 +96,7 @@ def scaled_quantize(
         .permute(0, 2, 1, 3)\
         .reshape(x.shape)
 
-    return x_scl_sat.to(dtype).contiguous(), scale.float().reciprocal()
+    return x_scl_sat.to(tgt_dtype).contiguous(), scale.float().reciprocal()
 
 
 # inverses `scaled_quantize`
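
For reference, a minimal usage sketch of the renamed keyword argument. The import path follows the file touched in this commit; the float8 target dtype, tensor shape, and group shape are illustrative assumptions (not part of the change), and float8 dtypes require a recent PyTorch build.

    import torch

    from vllm.model_executor.layers.quantization.utils.quant_utils import (
        scaled_quantize)

    # Illustrative input: a (128, 256) half-precision weight tile (assumed shape).
    x = torch.randn(128, 256, dtype=torch.float16)

    # One scale per 128x128 block; per the comment in the diff above,
    # (-1, -1) would be per-tensor and (1, -1) per-row quantization.
    x_q, inv_scale = scaled_quantize(x,
                                     group_shape=(128, 128),
                                     tgt_dtype=torch.float8_e4m3fn)

    # scaled_quantize returns the quantized tensor and the reciprocal scales.
    assert x_q.dtype == torch.float8_e4m3fn
    assert inv_scale.dtype == torch.float32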
