diff --git a/tests/kernels/test_rocm_attention_selector.py b/tests/kernels/test_rocm_attention_selector.py
new file mode 100644
index 0000000000000..5848dc014ca69
--- /dev/null
+++ b/tests/kernels/test_rocm_attention_selector.py
@@ -0,0 +1,31 @@
+# SPDX-License-Identifier: Apache-2.0
+
+from unittest.mock import patch
+
+import pytest
+import torch
+
+from tests.kernels.utils import override_backend_env_variable
+from vllm.attention.selector import _cached_get_attn_backend, get_attn_backend
+from vllm.platforms.rocm import RocmPlatform
+
+
+@pytest.fixture(autouse=True)
+def clear_cache():
+    """Clear the lru_cache so each test case runs without cached state.
+    """
+    _cached_get_attn_backend.cache_clear()
+
+
+def test_selector(monkeypatch):
+    """Test the attention backend selector on ROCm.
+    """
+    override_backend_env_variable(monkeypatch, "ROCM_FLASH")
+
+    with patch("vllm.attention.selector.current_platform", RocmPlatform()):
+        backend = get_attn_backend(16, torch.float16, torch.float16, 16, False)
+        assert backend.get_name() == "ROCM_FLASH"
+        # MLA test for DeepSeek-related models
+        backend = get_attn_backend(576, torch.bfloat16, "auto", 16, False,
+                                   False, True)
+        assert backend.get_name() == "TRITON_MLA"
diff --git a/tests/worker/test_model_runner.py b/tests/worker/test_model_runner.py
index aabe913c242e1..790699fb93cf3 100644
--- a/tests/worker/test_model_runner.py
+++ b/tests/worker/test_model_runner.py
@@ -22,6 +22,15 @@ def _create_model_runner(model: str, *args, **kwargs) -> ModelRunner:
     return model_runner
 
 
+def test_deepseek_mla_attn_backend_module():
+    model_runner = _create_model_runner(
+        "deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct",
+        trust_remote_code=True,
+        enable_chunked_prefill=False,
+    )
+    assert model_runner.attn_backend.__name__ == "TritonMLABackend"
+
+
 @pytest.mark.parametrize("batch_size", list(range(1, 257)))
 def test_prepare_prompt(batch_size):
     model_runner = _create_model_runner(
diff --git a/vllm/attention/backends/mla/utils.py b/vllm/attention/backends/mla/utils.py
index e8fec234c0225..722ebe3be86d4 100644
--- a/vllm/attention/backends/mla/utils.py
+++ b/vllm/attention/backends/mla/utils.py
@@ -25,7 +25,11 @@
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     scaled_dequantize, scaled_quantize)
 from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
-from vllm.vllm_flash_attn import flash_attn_varlen_func
+
+try:
+    from vllm.vllm_flash_attn import flash_attn_varlen_func
+except ImportError:
+    from flash_attn import flash_attn_varlen_func
 
 
 @dataclass
diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py
index 850820f66ff90..223da98645eee 100644
--- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py
+++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py
@@ -45,6 +45,16 @@ def apply_w8a8_block_fp8_linear(
 
     shape_supported_by_cutlass = (weight.shape[0] % 128 == 0
                                   and weight.shape[1] % 128 == 0)
+    if current_platform.is_rocm():
+        scale_a_shape = ((input_2d.shape[-1] // block_size[1], ) +
+                         input_2d.shape[:-1])[::-1]
+        scale_b_shape = (weight_scale.view(-1, 1)
+                         if weight_scale.dim() <= 1 else weight_scale.T).shape
+        ar, ac = scale_a_shape
+        br, bc = scale_b_shape
+        if (ac > 1 or bc > 1 or ar not in (1, input_2d.shape[0])
+                or br not in (1, weight.shape[0])):
+            shape_supported_by_cutlass = False
     if cutlass_block_fp8_supported and shape_supported_by_cutlass:
         q_input, x_scale = per_token_group_quant_fp8(input_2d,
                                                      block_size[1],
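Note: a minimal standalone sketch (not part of the patch; M, K, N and block_size
are hypothetical) of the scale-shape arithmetic behind the ROCm guard added
above, showing when the CUTLASS block-FP8 path is rejected:

# Sketch of the ROCm shape guard in apply_w8a8_block_fp8_linear: the CUTLASS
# path is skipped whenever either scale tensor would have more than one
# column, or a row count matching neither 1 nor the corresponding GEMM dim.
import torch

M, K, N = 8, 512, 256           # hypothetical GEMM shape (not from the patch)
block_size = [128, 128]
input_2d = torch.randn(M, K)
weight = torch.randn(N, K)
weight_scale = torch.randn(N // block_size[0], K // block_size[1])

scale_a_shape = ((input_2d.shape[-1] // block_size[1], ) +
                 input_2d.shape[:-1])[::-1]                   # (8, 4)
scale_b_shape = (weight_scale.view(-1, 1)
                 if weight_scale.dim() <= 1 else weight_scale.T).shape  # (4, 2)
ar, ac = scale_a_shape
br, bc = scale_b_shape
# Here ac == 4 > 1 and bc == 2 > 1, so shape_supported_by_cutlass would be
# forced to False and the non-CUTLASS fallback path taken on ROCm.
assert ac > 1 or bc > 1 or ar not in (1, M) or br not in (1, N)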
diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py
index 8888521631481..e1f07ea8972a4 100644
--- a/vllm/platforms/rocm.py
+++ b/vllm/platforms/rocm.py
@@ -77,6 +77,9 @@ class RocmPlatform(Platform):
     def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
                              kv_cache_dtype, block_size, use_v1,
                              use_mla) -> str:
+        if use_mla:
+            logger.info("Using Triton MLA backend.")
+            return "vllm.attention.backends.triton_mla.TritonMLABackend"
         selected_backend = (_Backend.ROCM_FLASH if selected_backend
                             == _Backend.FLASH_ATTN else selected_backend)
         if selected_backend == _Backend.ROCM_FLASH:
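Note: a usage sketch (not part of the patch) of how the new use_mla
short-circuit is reached; it mirrors the MLA assertion in the selector test
added above and assumes the same get_attn_backend positional signature:

from unittest.mock import patch

import torch

from vllm.attention.selector import get_attn_backend
from vllm.platforms.rocm import RocmPlatform

with patch("vllm.attention.selector.current_platform", RocmPlatform()):
    # head_size=576 and use_mla=True (the final positional argument) route
    # selection through RocmPlatform.get_attn_backend_cls, which now returns
    # the TritonMLABackend path before any ROCM_FLASH handling.
    backend = get_attn_backend(576, torch.bfloat16, "auto", 16, False,
                               False, True)
    assert backend.get_name() == "TRITON_MLA"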