Commit db2c583

filter compressed tensor models better
Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
1 parent: e144da8

File tree (2 files changed, +23 / -1 lines):

vllm/config.py
vllm/model_executor/model_loader/loader.py

vllm/config.py (+21)
@@ -971,8 +971,29 @@ def is_cross_encoder(self) -> bool:
     def use_mla(self) -> bool:
         if self.quantization is not None and self.quantization not in [\
                 "fp8", "compressed-tensors"]:
+            logger.warning(
+                "MLA is not supported with %s quantization. "
+                "Disabling MLA.", self.quantization)
             return False
 
+        # If using a "compressed-tensors" checkpoint, check that all groups
+        # have fp8 for both weights and activations.
+        if self.quantization == "compressed-tensors":
+            quant_config = self._parse_quant_hf_config()
+            for group_name, group_cfg in quant_config.get("config_groups",
+                                                          {}).items():
+                input_act_type = group_cfg.get("input_activations", {})\
+                    .get("type", "unknown").lower()
+                weights_type = group_cfg.get("weights", {})\
+                    .get("type", "unknown").lower()
+                if input_act_type != "fp8" or weights_type != "fp8":
+                    logger.warning(
+                        "compressed-tensors MLA support requires fp8 "
+                        "activations and weights in group '%s', but got "
+                        "activations type '%s' and weights type '%s'.",
+                        group_name, input_act_type, weights_type)
+                    return False
+
         use_mla = (self.is_deepseek_mla and not envs.VLLM_MLA_DISABLE)
         return use_mla
 
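For context, the sketch below is a standalone, hypothetical rewrite of the loop added above (it is not vLLM's API). The config dicts mimic the "config_groups" section of a compressed-tensors quantization_config; their shapes and the "fp8"/"int" type strings are illustrative values only.

# Minimal standalone sketch of the filter added in use_mla():
# every config group must declare fp8 for both weights and input
# activations, otherwise MLA support is ruled out.
def all_groups_are_fp8(quant_config: dict) -> bool:
    for group_name, group_cfg in quant_config.get("config_groups",
                                                  {}).items():
        input_act_type = group_cfg.get("input_activations", {}) \
            .get("type", "unknown").lower()
        weights_type = group_cfg.get("weights", {}) \
            .get("type", "unknown").lower()
        if input_act_type != "fp8" or weights_type != "fp8":
            print(f"group '{group_name}': activations '{input_act_type}', "
                  f"weights '{weights_type}' -> MLA disabled")
            return False
    return True

# Hypothetical config groups (illustrative, not from a real checkpoint).
fp8_ckpt = {"config_groups": {"group_0": {
    "input_activations": {"type": "fp8"}, "weights": {"type": "fp8"}}}}
int_ckpt = {"config_groups": {"group_0": {
    "input_activations": {"type": "int"}, "weights": {"type": "int"}}}}

assert all_groups_are_fp8(fp8_ckpt)         # passes the filter
assert not all_groups_are_fp8(int_ckpt)     # MLA would be disabled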

vllm/model_executor/model_loader/loader.py (+2, -1)
@@ -644,7 +644,8 @@ def load_model(self, vllm_config: VllmConfig) -> nn.Module:
                     hasattr(module, "process_weights_after_loading"):
                     # When attention modules need to process weights after
                     # currently only used by MLA
-                    module.process_weights_after_loading(model_config.dtype)
+                    module.process_weights_after_loading(
+                        model_config.dtype)
             rank = get_tensor_model_parallel_rank()
             pattern = os.path.join(
                 local_model_path,
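The loader hunk above is only a line-length rewrap of an existing call, but it shows the duck-typed hook that this loader relies on. A minimal sketch of that pattern follows; TinyMLAAttention is a hypothetical module, and only the hasattr probe and the hook name come from the diff.

import torch
from torch import nn

class TinyMLAAttention(nn.Module):
    # Hypothetical stand-in for an attention module that opts into
    # post-load weight processing via the probed hook.
    def __init__(self) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.randn(8, 8))

    def process_weights_after_loading(self, act_dtype: torch.dtype) -> None:
        # E.g. recast or repack weights once checkpoint loading is done.
        self.weight.data = self.weight.data.to(act_dtype)

model = nn.Sequential(TinyMLAAttention(), nn.Linear(8, 8))
for module in model.modules():
    if hasattr(module, "process_weights_after_loading"):
        module.process_weights_after_loading(torch.bfloat16)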
