diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py
index db55a31476fed..dfe93be462184 100644
--- a/vllm/attention/layer.py
+++ b/vllm/attention/layer.py
@@ -9,6 +9,7 @@
 from vllm.config import CacheConfig
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
+from vllm.model_executor.layers.quantization.fp8 import Fp8KVCacheMethod
 
 
 class Attention(nn.Module):
@@ -56,15 +57,19 @@ def __init__(
         quant_method = quant_config.get_quant_method(
             self) if quant_config else None
         if quant_method is not None:
-            if self.kv_cache_dtype == "fp8_e5m2":
-                raise ValueError("fp8_e5m2 kv-cache is not supported with "
-                                 "fp8 checkpoints.")
-            # When FP8 quantization is enabled, we make a parameter
-            # "kv_scale" so that it can be loaded from FP8 checkpoint.
-            # The kv_scale will then be converted back
-            # to self._kv_scale in a native float32 value after weight loading.
-            self.quant_method = quant_method
-            self.quant_method.create_weights(self)
+            assert isinstance(quant_method, Fp8KVCacheMethod)
+            # TODO (mgoin): kv cache dtype should be specified in the FP8
+            # checkpoint config and become the "auto" behavior
+            if "fp8" in self.kv_cache_dtype:
+                if self.kv_cache_dtype == "fp8_e5m2":
+                    raise ValueError("fp8_e5m2 kv-cache is not supported with "
+                                     "fp8 checkpoints.")
+                # When FP8 quantization is enabled, we make a parameter
+                # "kv_scale" so that it can be loaded from FP8 checkpoint.
+                # The kv_scale will then be converted back to self._kv_scale
+                # in a native float32 value after weight loading.
+                self.quant_method = quant_method
+                self.quant_method.create_weights(self)
 
         # During model initialization, the default dtype is set as the model
         # weight and activation dtype.
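
The comments in the hunk above describe a round trip for the KV-cache scaling factor: create_weights exposes a loadable "kv_scale" parameter on the attention layer, the FP8 checkpoint overwrites it during weight loading, and afterwards the value is converted back into a native float32 self._kv_scale. Below is a minimal, self-contained sketch of that pattern, not the actual Fp8KVCacheMethod implementation; the class name Fp8KVCacheMethodSketch, the process_weights_after_loading hook as written here, and the example scale value are illustrative assumptions.

import torch
from torch import nn


class Fp8KVCacheMethodSketch:
    """Illustrative (hypothetical) stand-in for the kv_scale handling
    described in the diff above."""

    def create_weights(self, layer: nn.Module) -> None:
        # Register a loadable "kv_scale" parameter, defaulting to 1.0, so a
        # per-layer scaling factor in an FP8 checkpoint can overwrite it.
        layer.kv_scale = nn.Parameter(torch.tensor(1.0), requires_grad=False)

    def process_weights_after_loading(self, layer: nn.Module) -> None:
        # After weight loading, convert the tensor back into a plain float
        # (_kv_scale) for the attention kernels and drop the parameter.
        layer._kv_scale = float(layer.kv_scale.item())
        del layer.kv_scale


# Usage sketch: pretend the checkpoint loader overwrote kv_scale.
layer = nn.Module()
method = Fp8KVCacheMethodSketch()
method.create_weights(layer)
layer.kv_scale.data.fill_(0.0423)   # value "loaded" from the FP8 checkpoint
method.process_weights_after_loading(layer)
print(layer._kv_scale)              # the scale is now a native Python float

Converting the loaded tensor back to a plain float after loading lets the attention forward pass consume a scalar directly rather than a registered parameter, which matches the intent stated in the diff's comments.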