@@ -971,8 +971,29 @@ def is_cross_encoder(self) -> bool:
     def use_mla(self) -> bool:
         if self.quantization is not None and self.quantization not in [\
                 "fp8", "compressed-tensors"]:
+            logger.warning(
+                "MLA is not supported with %s quantization. "
+                "Disabling MLA.", self.quantization)
             return False
 
+        # If using a "compressed-tensors" checkpoint, check that all groups
+        # have fp8 for both weights and activations.
+        if self.quantization == "compressed-tensors":
+            quant_config = self._parse_quant_hf_config()
+            for group_name, group_cfg in quant_config.get("config_groups",
+                                                          {}).items():
+                input_act_type = group_cfg.get("input_activations", {})\
+                    .get("type", "unknown").lower()
+                weights_type = group_cfg.get("weights", {})\
+                    .get("type", "unknown").lower()
+                if input_act_type != "fp8" or weights_type != "fp8":
+                    logger.warning(
+                        "compressed-tensors MLA support requires fp8 "
+                        "activations and weights in group '%s', but got "
+                        "activations type '%s' and weights type '%s'.",
+                        group_name, input_act_type, weights_type)
+                    return False
+
         use_mla = (self.is_deepseek_mla and not envs.VLLM_MLA_DISABLE)
         return use_mla
 
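For context (not part of the commit), here is a minimal sketch of the kind of parsed quantization config the new per-group check inspects. The dict shape simply mirrors the .get(...) lookups in the diff; the group name and the "fp8" type strings are illustrative assumptions, not values taken from a real checkpoint.

    # Hypothetical parsed HF quantization config (shape assumed from the
    # lookups in use_mla() above).
    quant_config = {
        "config_groups": {
            "group_0": {
                "weights": {"type": "fp8"},
                "input_activations": {"type": "fp8"},
            },
        },
    }

    # Same predicate as the added code: MLA stays enabled only if every group
    # quantizes both weights and activations to fp8.
    mla_ok = all(
        grp.get("weights", {}).get("type", "unknown").lower() == "fp8"
        and grp.get("input_activations", {}).get("type", "unknown").lower() == "fp8"
        for grp in quant_config.get("config_groups", {}).values()
    )
    print(mla_ok)  # True for this example; any non-fp8 group flips it to False

Under this sketch, a checkpoint that mixed in, say, int8 weights in one group would fail the check and take the warning path added above, so MLA would be disabled.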