Skip to content
New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

[Bugfix] Better FP8 supported defaults #12796

Merged
merged 4 commits into from
Feb 6, 2025
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 17 additions & 11 deletions vllm/model_executor/layers/quantization/utils/fp8_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from vllm.model_executor.layers.quantization.utils.quant_utils import (
_normalize_quant_group_shape, scaled_dequantize)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
apply_fp8_linear)
CUTLASS_BLOCK_FP8_SUPPORTED, CUTLASS_FP8_SUPPORTED, apply_fp8_linear)
from vllm.platforms import current_platform

logger = init_logger(__name__)
Expand All @@ -38,7 +38,7 @@ def apply_w8a8_block_fp8_linear(
weight_scale: torch.Tensor,
input_scale: Optional[torch.Tensor] = None,
bias: Optional[torch.Tensor] = None,
cutlass_block_fp8_supported: bool = True,
cutlass_block_fp8_supported: bool = CUTLASS_BLOCK_FP8_SUPPORTED,
) -> torch.Tensor:
assert input_scale is None
# View input as 2D matrix for fp8 methods
Expand Down Expand Up @@ -85,12 +85,14 @@ def apply_w8a8_block_fp8_linear(
# `apply_fp8_linear`
# NOTE(lucas): this is quite messy, we should think through this more formally
def apply_fp8_linear_generic(
input: torch.Tensor,
weight: torch.Tensor,
weight_scale: torch.Tensor,
input_group_shape: Tuple[int, int],
weight_group_shape: Tuple[int, int],
input_scale: Optional[torch.Tensor] = None, # static scale if one
input: torch.Tensor,
weight: torch.Tensor,
weight_scale: torch.Tensor,
input_group_shape: Tuple[int, int],
weight_group_shape: Tuple[int, int],
input_scale: Optional[torch.Tensor] = None, # static scale if one
cutlass_fp8_supported: bool = CUTLASS_FP8_SUPPORTED,
cutlass_block_fp8_supported: bool = CUTLASS_BLOCK_FP8_SUPPORTED,
) -> torch.Tensor:
# View input as 2D matrix for fp8 methods
input = input.view(-1, input.shape[-1])
Expand All @@ -105,14 +107,18 @@ def is_dim_blocked(dim, shape, group_shape):
if is_dim_blocked(0, weight.shape, weight_group_shape[0])\
and is_dim_blocked(1, weight.shape, weight_group_shape[1]) and\
input_group_shape == (1, weight_group_shape[1]):
return apply_w8a8_block_fp8_linear(input, weight,
list(weight_group_shape),
weight_scale)
return apply_w8a8_block_fp8_linear(
input,
weight,
list(weight_group_shape),
weight_scale,
cutlass_block_fp8_supported=cutlass_block_fp8_supported)
else:
# Despite having linear in the it doesn't conform to
# `torch.nn.functional.linear` which is defined as `input @ weight.T`
# so we explicitly transpose the weight matrix here
return apply_fp8_linear(input, weight.T, weight_scale.T,
cutlass_fp8_supported=cutlass_fp8_supported,
use_per_token_if_dynamic=\
(input_group_shape == (1, input.shape[1])))

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,10 @@ def cutlass_block_fp8_supported() -> bool:
return ops.cutlass_scaled_mm_supports_block_fp8(capability)


CUTLASS_FP8_SUPPORTED = cutlass_fp8_supported()
CUTLASS_BLOCK_FP8_SUPPORTED = cutlass_block_fp8_supported()


def per_tensor_dequantize(
tensor: torch.Tensor, inv_scale: Union[float,
torch.Tensor]) -> torch.Tensor:
Expand Down Expand Up @@ -109,7 +113,7 @@ def apply_fp8_linear(
input_scale: Optional[torch.Tensor] = None,
input_scale_ub: Optional[torch.Tensor] = None,
bias: Optional[torch.Tensor] = None,
cutlass_fp8_supported: bool = True,
cutlass_fp8_supported: bool = CUTLASS_BLOCK_FP8_SUPPORTED,
use_per_token_if_dynamic: bool = False,
) -> torch.Tensor:
# ops.scaled_fp8_quant supports both dynamic and static quant.
Expand Down