diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py
index 10ff71e57578e..99fbda314f6d3 100644
--- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py
+++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py
@@ -15,7 +15,7 @@
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     _normalize_quant_group_shape, scaled_dequantize)
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
-    apply_fp8_linear)
+    CUTLASS_BLOCK_FP8_SUPPORTED, CUTLASS_FP8_SUPPORTED, apply_fp8_linear)
 from vllm.platforms import current_platform
 
 logger = init_logger(__name__)
@@ -38,7 +38,7 @@ def apply_w8a8_block_fp8_linear(
     weight_scale: torch.Tensor,
     input_scale: Optional[torch.Tensor] = None,
     bias: Optional[torch.Tensor] = None,
-    cutlass_block_fp8_supported: bool = True,
+    cutlass_block_fp8_supported: bool = CUTLASS_BLOCK_FP8_SUPPORTED,
 ) -> torch.Tensor:
     assert input_scale is None
     # View input as 2D matrix for fp8 methods
@@ -85,12 +85,14 @@ def apply_w8a8_block_fp8_linear(
 # `apply_fp8_linear`
 # NOTE(lucas): this is quite messy, we should think through this more formally
 def apply_fp8_linear_generic(
-        input: torch.Tensor,
-        weight: torch.Tensor,
-        weight_scale: torch.Tensor,
-        input_group_shape: Tuple[int, int],
-        weight_group_shape: Tuple[int, int],
-        input_scale: Optional[torch.Tensor] = None,  # static scale if one
+    input: torch.Tensor,
+    weight: torch.Tensor,
+    weight_scale: torch.Tensor,
+    input_group_shape: Tuple[int, int],
+    weight_group_shape: Tuple[int, int],
+    input_scale: Optional[torch.Tensor] = None,  # static scale if one
+    cutlass_fp8_supported: bool = CUTLASS_FP8_SUPPORTED,
+    cutlass_block_fp8_supported: bool = CUTLASS_BLOCK_FP8_SUPPORTED,
 ) -> torch.Tensor:
     # View input as 2D matrix for fp8 methods
     input = input.view(-1, input.shape[-1])
@@ -105,14 +107,18 @@ def is_dim_blocked(dim, shape, group_shape):
     if is_dim_blocked(0, weight.shape, weight_group_shape[0])\
         and is_dim_blocked(1, weight.shape, weight_group_shape[1]) and\
         input_group_shape == (1, weight_group_shape[1]):
-        return apply_w8a8_block_fp8_linear(input, weight,
-                                           list(weight_group_shape),
-                                           weight_scale)
+        return apply_w8a8_block_fp8_linear(
+            input,
+            weight,
+            list(weight_group_shape),
+            weight_scale,
+            cutlass_block_fp8_supported=cutlass_block_fp8_supported)
     else:
         # Despite having linear in the it doesn't conform to
         # `torch.nn.functional.linear` which is defined as `input @ weight.T`
         # so we explicitly transpose the weight matrix here
         return apply_fp8_linear(input, weight.T, weight_scale.T,
+                                cutlass_fp8_supported=cutlass_fp8_supported,
                                 use_per_token_if_dynamic=\
                                     (input_group_shape == (1, input.shape[1])))
 
diff --git a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py
index 3fd88e8754a59..dedeb0c296bd4 100644
--- a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py
+++ b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py
@@ -42,6 +42,10 @@ def cutlass_block_fp8_supported() -> bool:
     return ops.cutlass_scaled_mm_supports_block_fp8(capability)
 
 
+CUTLASS_FP8_SUPPORTED = cutlass_fp8_supported()
+CUTLASS_BLOCK_FP8_SUPPORTED = cutlass_block_fp8_supported()
+
+
 def per_tensor_dequantize(
         tensor: torch.Tensor,
         inv_scale: Union[float, torch.Tensor]) -> torch.Tensor:
@@ -109,7 +113,7 @@ def apply_fp8_linear(
     input_scale: Optional[torch.Tensor] = None,
     input_scale_ub: Optional[torch.Tensor] = None,
     bias: Optional[torch.Tensor] = None,
-    cutlass_fp8_supported: bool = True,
+    cutlass_fp8_supported: bool = CUTLASS_FP8_SUPPORTED,
     use_per_token_if_dynamic: bool = False,
 ) -> torch.Tensor:
     # ops.scaled_fp8_quant supports both dynamic and static quant.
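
The patch follows a "probe once, reuse as default" pattern: the CUTLASS capability checks are evaluated a single time at module import, cached in `CUTLASS_FP8_SUPPORTED` / `CUTLASS_BLOCK_FP8_SUPPORTED`, and wired in as keyword defaults in place of the hard-coded `True`. Because Python binds default values at function-definition time, call sites that omit the flag now pick up the platform-aware result automatically, while callers that pass the flag explicitly are unaffected. Below is a minimal, self-contained sketch of that pattern; `_probe_cutlass_support` and `fused_linear` are hypothetical names used for illustration, not vLLM APIs.

```python
# Minimal sketch of the "probe once, reuse as default" pattern applied in the
# patch above. All names here are illustrative stand-ins, not vLLM symbols.


def _probe_cutlass_support() -> bool:
    # Stand-in for cutlass_fp8_supported(): a real probe would query the GPU
    # compute capability; here we simply report "not supported".
    return False


# Evaluated exactly once, when the module is first imported.
CUTLASS_SUPPORTED = _probe_cutlass_support()


def fused_linear(x, w, cutlass_supported: bool = CUTLASS_SUPPORTED) -> str:
    # Defaults are bound at function-definition time, so callers that omit the
    # flag get the cached, platform-aware value; passing it still overrides.
    return "cutlass path" if cutlass_supported else "fallback path"


if __name__ == "__main__":
    print(fused_linear(None, None))        # uses the import-time default
    print(fused_linear(None, None, True))  # explicit override by the caller
```

The practical effect of the diff is that `apply_fp8_linear`, `apply_w8a8_block_fp8_linear`, and `apply_fp8_linear_generic` no longer default to the CUTLASS path on platforms where the kernels are unavailable, without requiring any changes at existing call sites.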