Skip to content
New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

[AMD][ROCm] Enable DeepSeek model on ROCm #12662

Merged
merged 6 commits into from
Feb 4, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 31 additions & 0 deletions tests/kernels/test_rocm_attention_selector.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
# SPDX-License-Identifier: Apache-2.0

from unittest.mock import patch

import pytest
import torch

from tests.kernels.utils import override_backend_env_variable
from vllm.attention.selector import _cached_get_attn_backend, get_attn_backend
from vllm.platforms.rocm import RocmPlatform


@pytest.fixture(autouse=True)
def clear_cache():
"""Clear lru cache to ensure each test case runs without caching.
"""
_cached_get_attn_backend.cache_clear()


def test_selector(monkeypatch):
"""Test that the attention selector for ROCm.
"""
override_backend_env_variable(monkeypatch, "ROCM_FLASH")

with patch("vllm.attention.selector.current_platform", RocmPlatform()):
backend = get_attn_backend(16, torch.float16, torch.float16, 16, False)
assert backend.get_name() == "ROCM_FLASH"
# mla test for deepseek related
backend = get_attn_backend(576, torch.bfloat16, "auto", 16, False,
False, True)
assert backend.get_name() == "TRITON_MLA"
9 changes: 9 additions & 0 deletions tests/worker/test_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,15 @@ def _create_model_runner(model: str, *args, **kwargs) -> ModelRunner:
return model_runner


def test_deepseek_mla_attn_backend_module():
model_runner = _create_model_runner(
"deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct",
trust_remote_code=True,
enable_chunked_prefill=False,
)
assert model_runner.attn_backend.__name__ == "TritonMLABackend"


@pytest.mark.parametrize("batch_size", list(range(1, 257)))
def test_prepare_prompt(batch_size):
model_runner = _create_model_runner(
Expand Down
6 changes: 5 additions & 1 deletion vllm/attention/backends/mla/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,11 @@
from vllm.model_executor.layers.quantization.utils.quant_utils import (
scaled_dequantize, scaled_quantize)
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
from vllm.vllm_flash_attn import flash_attn_varlen_func

try:
from vllm.vllm_flash_attn import flash_attn_varlen_func
except ImportError:
from flash_attn import flash_attn_varlen_func


@dataclass
Expand Down
10 changes: 10 additions & 0 deletions vllm/model_executor/layers/quantization/utils/fp8_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,16 @@ def apply_w8a8_block_fp8_linear(

shape_supported_by_cutlass = (weight.shape[0] % 128 == 0
and weight.shape[1] % 128 == 0)
if current_platform.is_rocm():
scale_a_shape = ((input_2d.shape[-1] // block_size[1], ) +
input_2d.shape[:-1])[::-1]
scale_b_shape = (weight_scale.view(-1, 1)
if weight_scale.dim() <= 1 else weight_scale.T).shape
ar, ac = scale_a_shape
br, bc = scale_b_shape
if (ac > 1 or bc > 1 or ar not in (1, input_2d.shape[0])
or br not in (1, weight.shape[0])):
shape_supported_by_cutlass = False
if cutlass_block_fp8_supported and shape_supported_by_cutlass:
q_input, x_scale = per_token_group_quant_fp8(input_2d,
block_size[1],
Expand Down
3 changes: 3 additions & 0 deletions vllm/platforms/rocm.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,9 @@ class RocmPlatform(Platform):
def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
kv_cache_dtype, block_size, use_v1,
use_mla) -> str:
if use_mla:
logger.info("Using Triton MLA backend.")
return "vllm.attention.backends.triton_mla.TritonMLABackend"
selected_backend = (_Backend.ROCM_FLASH if selected_backend
== _Backend.FLASH_ATTN else selected_backend)
if selected_backend == _Backend.ROCM_FLASH:
Expand Down