
Commit 8bdc14a

review comments
Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
1 parent: d27826d

6 files changed: +13 -23 lines

vllm/attention/layer.py
+2 -2

@@ -44,7 +44,7 @@ def __init__(
         use_mla: bool = False,
         prefix: str = "",
         attn_type: str = AttentionType.DECODER,
-        **kwargs,
+        **extra_impl_args,
     ) -> None:
         super().__init__()
         if per_layer_sliding_window is not None:
@@ -114,7 +114,7 @@ def __init__(
         self.impl = impl_cls(num_heads, head_size, scale, num_kv_heads,
                              alibi_slopes, sliding_window, kv_cache_dtype,
                              blocksparse_params, logits_soft_cap, attn_type,
-                             **kwargs)
+                             **extra_impl_args)
         self.num_heads = num_heads
         self.head_size = head_size
         self.num_kv_heads = num_kv_heads
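
The rename from `**kwargs` to `**extra_impl_args` makes the contract explicit: `Attention` never inspects these keyword arguments, it only forwards them to the backend implementation class. A minimal sketch of the pattern, with a hypothetical `DummyImpl` backend and `fancy_knob` parameter that are not part of vllm:

    class DummyImpl:
        """Hypothetical stand-in for a real attention backend impl."""

        def __init__(self, num_heads: int, head_size: int, **extra_impl_args):
            self.num_heads = num_heads
            self.head_size = head_size
            # Backend-specific knobs arrive here untouched.
            self.fancy_knob = extra_impl_args.get("fancy_knob", False)

    class AttentionSketch:
        """Simplified wrapper mirroring the forwarding in layer.py."""

        def __init__(self, num_heads: int, head_size: int, **extra_impl_args):
            # The wrapper only forwards the extras; the new name signals that.
            self.impl = DummyImpl(num_heads, head_size, **extra_impl_args)

    attn = AttentionSketch(num_heads=8, head_size=64, fancy_knob=True)
    assert attn.impl.fancy_knob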

vllm/config.py
+3 -7

@@ -165,7 +165,6 @@ class ModelConfig:
             `logits_processors` extra completion argument. Defaults to None,
             which allows no processors.
         generation_config: Configuration parameter file for generation.
-        disable_mla: Whether to disable MLA for DeepSeek models.
         override_generation_config: Override the generation config with the
             given config.
     """
@@ -227,7 +226,6 @@ def __init__(
         override_pooler_config: Optional["PoolerConfig"] = None,
         logits_processor_pattern: Optional[str] = None,
         generation_config: Optional[str] = None,
-        disable_mla: bool = False,
         enable_sleep_mode: bool = False,
         override_generation_config: Optional[Dict[str, Any]] = None,
     ) -> None:
@@ -278,7 +276,6 @@ def __init__(
         self.max_logprobs = max_logprobs
         self.disable_sliding_window = disable_sliding_window
         self.skip_tokenizer_init = skip_tokenizer_init
-        self.disable_mla = disable_mla
         self.enable_sleep_mode = enable_sleep_mode

         from vllm.platforms import current_platform
@@ -748,7 +745,7 @@ def is_deepseek_mla(self) -> bool:
     def get_head_size(self) -> int:
         # TODO remove hard code
         if self.is_deepseek_mla:
-            if self.should_use_mla:
+            if self.use_mla:
                 return self.hf_text_config.kv_lora_rank
             else:
                 qk_rope_head_dim = getattr(self.hf_text_config,
@@ -815,7 +812,7 @@ def get_total_num_kv_heads(self) -> int:

     def get_num_kv_heads(self, parallel_config: "ParallelConfig") -> int:
         """Returns the number of KV heads per GPU."""
-        if self.should_use_mla:
+        if self.use_mla:
             # When using MLA during decode it becomes MQA
             return 1

@@ -971,8 +968,7 @@ def is_cross_encoder(self) -> bool:

     @property
     def use_mla(self) -> bool:
-        use_mla = (self.is_deepseek_mla and not self.disable_mla
-                   and not envs.VLLM_MLA_DISABLE)
+        use_mla = (self.is_deepseek_mla and not envs.VLLM_MLA_DISABLE)
         return use_mla

     def supported_runner_types(self) -> Set[RunnerType]:
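
With the `disable_mla` config option removed, the only remaining switch for MLA is the `VLLM_MLA_DISABLE` environment variable. A simplified sketch of the resulting property logic (real vllm reads the variable through its `envs` module; this stand-in parses it directly):

    import os

    class ModelConfigSketch:
        """Simplified stand-in for ModelConfig.use_mla after this commit."""

        def __init__(self, is_deepseek_mla: bool):
            self.is_deepseek_mla = is_deepseek_mla

        @property
        def use_mla(self) -> bool:
            # MLA stays on for DeepSeek MLA models unless disabled globally
            # via the environment; there is no per-model flag anymore.
            mla_disabled = os.environ.get("VLLM_MLA_DISABLE", "0") == "1"
            return self.is_deepseek_mla and not mla_disabled

    cfg = ModelConfigSketch(is_deepseek_mla=True)
    print(cfg.use_mla)  # True unless VLLM_MLA_DISABLE=1 is set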

vllm/engine/arg_utils.py
-5

@@ -100,7 +100,6 @@ class EngineArgs:
     kv_cache_dtype: str = 'auto'
     seed: int = 0
     max_model_len: Optional[int] = None
-    disable_mla: bool = False
     # Note: Specifying a custom executor backend by passing a class
     # is intended for expert use only. The API may change without
     # notice.
@@ -932,9 +931,6 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
                             type=str,
                             default="auto",
                             help='The worker class to use for distributed execution.')
-        parser.add_argument('--disable-mla',
-                            action='store_true',
-                            help='Disable MLA for DeepSeek models.')
         parser.add_argument(
             "--generation-config",
             type=nullable_str,
@@ -1015,7 +1011,6 @@ def create_model_config(self) -> ModelConfig:
             disable_mm_preprocessor_cache=self.disable_mm_preprocessor_cache,
             override_neuron_config=self.override_neuron_config,
             override_pooler_config=self.override_pooler_config,
-            disable_mla=self.disable_mla,
             logits_processor_pattern=self.logits_processor_pattern,
             generation_config=self.generation_config,
             override_generation_config=self.override_generation_config,
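
Since `--disable-mla` no longer exists, anyone who relied on the flag should set the environment variable instead. A hedged usage sketch (the model name is illustrative, not prescribed by this commit):

    import os

    # Set before vllm reads its environment configuration.
    os.environ["VLLM_MLA_DISABLE"] = "1"

    from vllm import LLM

    # Illustrative DeepSeek MLA model; any such model is affected the same way.
    llm = LLM(model="deepseek-ai/DeepSeek-V2-Lite")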

vllm/model_executor/models/deepseek_v2.py
+1 -1

@@ -488,7 +488,7 @@ def __init__(
         # DecoderLayers are created with `make_layers` which passes the prefix
         # with the layer's index.
         layer_idx = int(prefix.split(sep='.')[-1])
-        if model_config.should_use_mla:
+        if model_config.use_mla:
             attn_cls = DeepseekV2MLAAttention
         else:
             attn_cls = DeepseekV2Attention
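
The `should_use_mla` to `use_mla` rename is applied uniformly, so every call site now reads as a plain property check. A minimal sketch of the per-layer dispatch this enables (both attention classes are stubs here, not the real implementations):

    class DeepseekV2Attention:
        """Stub for the standard attention path."""

    class DeepseekV2MLAAttention:
        """Stub for the MLA attention path."""

    def pick_attn_cls(model_config) -> type:
        # One property check selects the attention implementation per layer.
        return (DeepseekV2MLAAttention
                if model_config.use_mla else DeepseekV2Attention)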

vllm/worker/cache_engine.py
+6 -7

@@ -52,13 +52,12 @@ def __init__(
         self.dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype]

         # Get attention backend.
-        self.attn_backend = get_attn_backend(
-            self.head_size,
-            model_config.dtype,
-            cache_config.cache_dtype,
-            self.block_size,
-            model_config.is_attention_free,
-            use_mla=model_config.should_use_mla)
+        self.attn_backend = get_attn_backend(self.head_size,
+                                             model_config.dtype,
+                                             cache_config.cache_dtype,
+                                             self.block_size,
+                                             model_config.is_attention_free,
+                                             use_mla=model_config.use_mla)

         # Initialize the cache.
         self.gpu_cache = self._allocate_kv_cache(

vllm/worker/model_runner.py
+1 -1

@@ -1066,7 +1066,7 @@ def __init__(
             self.kv_cache_dtype,
             self.block_size,
             self.model_config.is_attention_free,
-            use_mla=self.model_config.should_use_mla,
+            use_mla=self.model_config.use_mla,
         ) if needs_attn_backend else None
         if self.attn_backend:
             self.attn_state = self.attn_backend.get_state_cls()(
