
Commit 634eee6

Address review comments
Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
1 parent 31b802c commit 634eee6

File tree

2 files changed: +6 -7 lines


vllm/config.py

+1 -3

@@ -748,8 +748,6 @@ def is_deepseek_mla(self) -> bool:
     def get_head_size(self) -> int:
         # TODO remove hard code
         if self.is_deepseek_mla:
-            # FlashAttention supports only head_size 32, 64, 128, 256,
-            # we need to pad head_size 192 to 256
             if self.should_use_mla:
                 return self.hf_text_config.kv_lora_rank
             else:
@@ -974,7 +972,7 @@ def is_cross_encoder(self) -> bool:
     @property
     def should_use_mla(self) -> bool:
         use_mla = (self.is_deepseek_mla and not self.disable_mla
-                   and not envs.VLLM_DISABLE_MLA)
+                   and not envs.VLLM_MLA_DISABLE)
         return use_mla

     def supported_runner_types(self) -> Set[RunnerType]:
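For context, a minimal self-contained sketch of the gating this hunk touches; the class, attribute names, and the kv_lora_rank default below are illustrative stand-ins, not vLLM's actual ModelConfig:

# Illustrative sketch only: mirrors the head-size / MLA gating shown in the
# hunks above, with hypothetical attributes in place of vLLM's real config.
import os


class SketchModelConfig:
    def __init__(self, is_deepseek_mla: bool, disable_mla: bool,
                 kv_lora_rank: int = 512):
        self.is_deepseek_mla = is_deepseek_mla
        self.disable_mla = disable_mla
        self.kv_lora_rank = kv_lora_rank  # stand-in for hf_text_config.kv_lora_rank

    @property
    def should_use_mla(self) -> bool:
        # Reads the renamed flag: VLLM_MLA_DISABLE (previously VLLM_DISABLE_MLA).
        mla_disabled = bool(int(os.getenv("VLLM_MLA_DISABLE", "0")))
        return self.is_deepseek_mla and not self.disable_mla and not mla_disabled

    def get_head_size(self) -> int:
        if self.is_deepseek_mla and self.should_use_mla:
            # MLA path: the effective head size is the compressed KV LoRA rank.
            return self.kv_lora_rank
        # The real method falls through to the regular (non-MLA) sizing here.
        raise NotImplementedError("non-MLA sizing not sketched")


if __name__ == "__main__":
    cfg = SketchModelConfig(is_deepseek_mla=True, disable_mla=False)
    print(cfg.get_head_size())  # 512, unless VLLM_MLA_DISABLE=1 is exported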

vllm/envs.py

+5 -4

@@ -77,6 +77,7 @@
     V_SCALE_CONSTANT: int = 100
     VLLM_SERVER_DEV_MODE: bool = False
     VLLM_V1_OUTPUT_PROC_CHUNK_SIZE: int = 128
+    VLLM_MLA_DISABLE: bool = False
     VLLM_MLA_PERFORM_MATRIX_ABSORPTION: bool = True


@@ -302,10 +303,6 @@ def maybe_convert_int(value: Optional[str]) -> Optional[int]:
     "VLLM_FLASHINFER_FORCE_TENSOR_CORES":
     lambda: bool(int(os.getenv("VLLM_FLASHINFER_FORCE_TENSOR_CORES", "0"))),

-    # If set, vLLM will disable the MLA attention optimizations.
-    "VLLM_DISABLE_MLA":
-    lambda: bool(int(os.getenv("VLLM_DISABLE_MLA", "0"))),
-
     # Pipeline stage partition strategy
     "VLLM_PP_LAYER_PARTITION":
     lambda: os.getenv("VLLM_PP_LAYER_PARTITION", None),
@@ -512,6 +509,10 @@ def maybe_convert_int(value: Optional[str]) -> Optional[int]:
     "VLLM_V1_OUTPUT_PROC_CHUNK_SIZE":
     lambda: int(os.getenv("VLLM_V1_OUTPUT_PROC_CHUNK_SIZE", "128")),

+    # If set, vLLM will disable the MLA attention optimizations.
+    "VLLM_MLA_DISABLE":
+    lambda: bool(int(os.getenv("VLLM_MLA_DISABLE", "0"))),
+
     # Flag that can control whether or not we perform matrix-absorption for MLA
     # decode, i.e. absorb W_UK into W_Q/W_UK and W_UV into W_O, absorbing the
     # matrices reduces the runtime FLOPs needed to compute MLA but requires
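A small usage sketch of the renamed flag; the helper below is not part of vLLM, it just reproduces the bool(int(os.getenv(...))) pattern added in this diff:

# Not vLLM code: reproduces how the renamed flag is parsed in vllm/envs.py.
# Note that bool(int(...)) expects "0"/"1"-style values; a value like "true"
# would raise ValueError rather than being coerced.
import os


def read_vllm_mla_disable() -> bool:
    # Equivalent to envs.VLLM_MLA_DISABLE after this commit (default: False).
    return bool(int(os.getenv("VLLM_MLA_DISABLE", "0")))


if __name__ == "__main__":
    # e.g. `VLLM_MLA_DISABLE=1 python this_sketch.py` prints True,
    # which should_use_mla in vllm/config.py uses to turn MLA off.
    print(read_vllm_mla_disable())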
