[AMD][ROCm] Enable DeepSeek model on ROCm (vllm-project#12662)
Signed-off-by: Hongxia Yang <hongxia.yang@amd.com>
Co-authored-by: Matthew Wong <Matthew.Wong2@amd.com>
2 people authored and panf2333 committed Feb 18, 2025
1 parent 5937fbe commit 659d02d
Showing 5 changed files with 58 additions and 1 deletion.
31 changes: 31 additions & 0 deletions tests/kernels/test_rocm_attention_selector.py
@@ -0,0 +1,31 @@
# SPDX-License-Identifier: Apache-2.0

from unittest.mock import patch

import pytest
import torch

from tests.kernels.utils import override_backend_env_variable
from vllm.attention.selector import _cached_get_attn_backend, get_attn_backend
from vllm.platforms.rocm import RocmPlatform


@pytest.fixture(autouse=True)
def clear_cache():
    """Clear lru cache to ensure each test case runs without caching.
    """
    _cached_get_attn_backend.cache_clear()


def test_selector(monkeypatch):
    """Test that the attention selector picks the correct backend on ROCm.
    """
    override_backend_env_variable(monkeypatch, "ROCM_FLASH")

    with patch("vllm.attention.selector.current_platform", RocmPlatform()):
        backend = get_attn_backend(16, torch.float16, torch.float16, 16, False)
        assert backend.get_name() == "ROCM_FLASH"
        # MLA test for DeepSeek-related models
        backend = get_attn_backend(576, torch.bfloat16, "auto", 16, False,
                                   False, True)
        assert backend.get_name() == "TRITON_MLA"
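
For a quick manual check outside pytest, the same selector calls can be run directly. The sketch below mirrors the MLA case of the test above and assumes a vLLM install where get_attn_backend accepts the same positional arguments used in the test:

# Minimal standalone sketch (assumption: get_attn_backend keeps the positional
# argument order used in the test above).
from unittest.mock import patch

import torch

from vllm.attention.selector import get_attn_backend
from vllm.platforms.rocm import RocmPlatform

with patch("vllm.attention.selector.current_platform", RocmPlatform()):
    # head_size=576 with the trailing flag enabled selects the MLA path used by DeepSeek.
    backend = get_attn_backend(576, torch.bfloat16, "auto", 16, False, False, True)
    print(backend.get_name())  # expected: "TRITON_MLA"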
9 changes: 9 additions & 0 deletions tests/worker/test_model_runner.py
@@ -24,6 +24,15 @@ def _create_model_runner(model: str, *args, **kwargs) -> ModelRunner:
    return model_runner


def test_deepseek_mla_attn_backend_module():
    model_runner = _create_model_runner(
        "deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct",
        trust_remote_code=True,
        enable_chunked_prefill=False,
    )
    assert model_runner.attn_backend.__name__ == "TritonMLABackend"


@pytest.mark.parametrize("batch_size", list(range(1, 257)))
def test_prepare_prompt(batch_size):
    model_runner = _create_model_runner(
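
The unit test above only asserts the backend class name; an end-to-end smoke test would actually load the model. A hedged sketch using vLLM's public LLM API, assuming a ROCm GPU with enough memory for DeepSeek-Coder-V2-Lite-Instruct and network access to download it:

# End-to-end smoke-test sketch (assumptions: a ROCm GPU with enough memory and
# the standard vllm.LLM / SamplingParams API; this is not part of the commit).
from vllm import LLM, SamplingParams

llm = LLM(model="deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct",
          trust_remote_code=True)
outputs = llm.generate(["Write a hello-world program in Python."],
                       SamplingParams(max_tokens=64))
print(outputs[0].outputs[0].text)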
6 changes: 5 additions & 1 deletion vllm/attention/backends/mla/utils.py
@@ -27,7 +27,11 @@
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    scaled_dequantize, scaled_quantize)
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
from vllm.vllm_flash_attn import flash_attn_varlen_func

try:
    from vllm.vllm_flash_attn import flash_attn_varlen_func
except ImportError:
    from flash_attn import flash_attn_varlen_func


@dataclass
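
The guarded import above prefers the bundled vllm.vllm_flash_attn extension and falls back to the standalone flash_attn package when the extension is not present, as can happen on ROCm builds. A small sketch, assuming at least one of the two packages is installed, to confirm which module ends up providing the symbol:

# Sketch: report which package supplies flash_attn_varlen_func
# (assumption: at least one of the two imports below succeeds).
try:
    from vllm.vllm_flash_attn import flash_attn_varlen_func
except ImportError:
    from flash_attn import flash_attn_varlen_func

print("flash_attn_varlen_func comes from:", flash_attn_varlen_func.__module__)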
10 changes: 10 additions & 0 deletions vllm/model_executor/layers/quantization/utils/fp8_utils.py
@@ -47,6 +47,16 @@ def apply_w8a8_block_fp8_linear(

    shape_supported_by_cutlass = (weight.shape[0] % 128 == 0
                                  and weight.shape[1] % 128 == 0)
    if current_platform.is_rocm():
        scale_a_shape = ((input_2d.shape[-1] // block_size[1], ) +
                         input_2d.shape[:-1])[::-1]
        scale_b_shape = (weight_scale.view(-1, 1)
                         if weight_scale.dim() <= 1 else weight_scale.T).shape
        ar, ac = scale_a_shape
        br, bc = scale_b_shape
        if (ac > 1 or bc > 1 or ar not in (1, input_2d.shape[0])
                or br not in (1, weight.shape[0])):
            shape_supported_by_cutlass = False
    if cutlass_block_fp8_supported and shape_supported_by_cutlass:
        q_input, x_scale = per_token_group_quant_fp8(input_2d,
                                                     block_size[1],
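
The new ROCm branch disables the CUTLASS block-FP8 path whenever the activation or weight scales carry more than one block along the quantized dimensions, in which case execution falls through to the non-CUTLASS block-FP8 kernel. A worked sketch of that shape logic with plain integers follows; the 128x128 block size and the (N // 128, K // 128) weight-scale layout are assumptions based on the surrounding code:

# Worked sketch of the ROCm scale-shape check above, using plain integers.
M, K = 8, 512              # input_2d is (M, K)
N = 1024                   # weight is (N, K)
block_n, block_k = 128, 128

scale_a_shape = ((K // block_k,) + (M,))[::-1]   # -> (M, K // block_k) = (8, 4)
scale_b_shape = (K // block_k, N // block_n)     # weight_scale.T -> (4, 8), assumed layout
ar, ac = scale_a_shape
br, bc = scale_b_shape

supported = not (ac > 1 or bc > 1 or ar not in (1, M) or br not in (1, N))
print(supported)  # False: more than one scale block along K (ac > 1), so
                  # shape_supported_by_cutlass stays False on ROCm and the
                  # non-CUTLASS block-FP8 path is taken instead.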
3 changes: 3 additions & 0 deletions vllm/platforms/rocm.py
@@ -79,6 +79,9 @@ class RocmPlatform(Platform):
    def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
                             kv_cache_dtype, block_size, use_v1,
                             use_mla) -> str:
        if use_mla:
            logger.info("Using Triton MLA backend.")
            return "vllm.attention.backends.triton_mla.TritonMLABackend"
        selected_backend = (_Backend.ROCM_FLASH if selected_backend
                            == _Backend.FLASH_ATTN else selected_backend)
        if selected_backend == _Backend.ROCM_FLASH:
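
With this change, any ROCm configuration that requests MLA is routed to the Triton MLA backend before the FLASH_ATTN to ROCM_FLASH remapping runs. A minimal sketch that calls the platform hook directly, assuming the exact argument list shown in the diff and that selected_backend and dtype are not inspected on the use_mla early return:

# Sketch: exercise the MLA branch of RocmPlatform.get_attn_backend_cls directly
# (assumptions as stated in the lead-in; values mirror the tests above).
import torch

from vllm.platforms.rocm import RocmPlatform

backend_cls = RocmPlatform.get_attn_backend_cls(
    selected_backend=None,
    head_size=576,            # DeepSeek MLA head size used in the tests above
    dtype=torch.bfloat16,
    kv_cache_dtype="auto",
    block_size=16,
    use_v1=False,
    use_mla=True,
)
assert backend_cls == "vllm.attention.backends.triton_mla.TritonMLABackend"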
