diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu
index 72d549e597df5..e40f282299685 100644
--- a/csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu
+++ b/csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu
@@ -16,29 +16,11 @@ void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
   TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
   TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
 
-  using GroupShape = std::array<int64_t, 2>;
-  int M = a.size(0), N = b.size(1), K = a.size(1);
-  GroupShape a_scale_group_shape = [&, &s = a_scales]() -> GroupShape {
-    if (s.numel() == 1) return {M, K};  // tensor-wise
-    if (s.dim() == 2)
-      return {ceil_div(a.size(0), s.size(0)), ceil_div(a.size(1), s.size(1))};
-    TORCH_CHECK(false, "Unsupported scale shape for scale_a");
-  }();
-
-  GroupShape b_scale_group_shape = [&, &s = b_scales]() -> GroupShape {
-    if (s.numel() == 1) return {K, N};  // tensor-wise
-    if (s.dim() == 2)
-      return {ceil_div(b.size(0), s.size(0)), ceil_div(b.size(1), s.size(1))};
-    TORCH_CHECK(false, "Unsupported scale shape for scale_b");
-  }();
-
-  if ((a_scale_group_shape == GroupShape{M, K} ||
-       a_scale_group_shape == GroupShape{1, K}) &&
-      (b_scale_group_shape == GroupShape{K, N} ||
-       b_scale_group_shape == GroupShape{K, 1})) {
-    // "standard per-tensor/per-token/per-channel" scaling
+  if ((a_scales.numel() == 1 || a_scales.numel() == a.size(0)) &&
+      (b_scales.numel() == 1 || b_scales.numel() == b.size(1))) {
+    // Standard per-tensor/per-token/per-channel scaling
     TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
     if (a.dtype() == torch::kFloat8_e4m3fn) {
       vllm::cutlass_scaled_mm_sm90_fp8(c, a, b, a_scales, b_scales, bias);
@@ -46,25 +28,32 @@ void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
       TORCH_CHECK(a.dtype() == torch::kInt8);
       vllm::cutlass_scaled_mm_sm90_int8(c, a, b, a_scales, b_scales, bias);
     }
-  } else if (a_scale_group_shape == GroupShape{1, 128} &&
-             b_scale_group_shape == GroupShape{128, 128}) {
+  } else {
+    using GroupShape = std::array<int64_t, 2>;
+    auto make_group_shape = [](torch::Tensor const& x,
+                               torch::Tensor const& s) -> GroupShape {
+      TORCH_CHECK(s.dim() == 2, "cutlass_scaled_mm group scales must be 2D");
+      return {ceil_div(x.size(0), s.size(0)), ceil_div(x.size(1), s.size(1))};
+    };
+
+    GroupShape a_scale_group_shape = make_group_shape(a, a_scales);
+    GroupShape b_scale_group_shape = make_group_shape(b, b_scales);
+
     // 1x128 per-token group scales for activations
     // 128x128 blockwise scales for weights
-    TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn &&
-                b.dtype() == torch::kFloat8_e4m3fn,
-                "Currently only FP8 is supported for A group shape 1x128 and "
-                "B group shape 128x128");
-    TORCH_CHECK(!bias, "Bias not yet supported blockwise scaled_mm");
-
-    vllm::cutlass_scaled_mm_blockwise_sm90_fp8(c, a, b, a_scales, b_scales);
-  } else {
-    TORCH_CHECK(false,
-                "Unsupported scale group shapes for CUTLASS 3.x GEMM.\n "
-                "a_scale_group_shape must be [1, 128], got: [",
+    TORCH_CHECK((a_scale_group_shape == GroupShape{1, 128} &&
+                 b_scale_group_shape == GroupShape{128, 128} &&
+                 a.dtype() == torch::kFloat8_e4m3fn &&
+                 b.dtype() == torch::kFloat8_e4m3fn),
+                "cutlass_scaled_mm only supports datatype float8_e4m3fn.\n"
+                "a_scale_group_shape must be [1, 128]. Got: [",
                 a_scale_group_shape[0], ", ", a_scale_group_shape[1],
                 "]\n"
-                "b_scale_group_shape must be [128, 128], got: [",
+                "b_scale_group_shape must be [128, 128]. Got: [",
                 b_scale_group_shape[0], ", ", b_scale_group_shape[1], "]");
+    TORCH_CHECK(!bias, "Bias not yet supported blockwise scaled_mm");
+
+    vllm::cutlass_scaled_mm_blockwise_sm90_fp8(c, a, b, a_scales, b_scales);
   }
 }