From 12d058349113b814072b758f4cda266f30dd4910 Mon Sep 17 00:00:00 2001
From: Kunshang Ji
Date: Fri, 24 Jan 2025 16:05:57 +0800
Subject: [PATCH 1/3] add xpu bf16 support

Signed-off-by: Kunshang Ji
---
 .../installation/gpu/xpu.inc.md |  2 +-
 vllm/platforms/xpu.py           | 24 ++++++++++++++++---
 2 files changed, 22 insertions(+), 4 deletions(-)

diff --git a/docs/source/getting_started/installation/gpu/xpu.inc.md b/docs/source/getting_started/installation/gpu/xpu.inc.md
index 4116826789e5c..ad59180b96053 100644
--- a/docs/source/getting_started/installation/gpu/xpu.inc.md
+++ b/docs/source/getting_started/installation/gpu/xpu.inc.md
@@ -36,7 +36,7 @@ VLLM_TARGET_DEVICE=xpu python setup.py install
 
 :::{note}
 - FP16 is the default data type in the current XPU backend. The BF16 data
-  type will be supported in the future.
+  type is experimental supported now.
 :::
 
 ## Set up using Docker
diff --git a/vllm/platforms/xpu.py b/vllm/platforms/xpu.py
index a5ca77f57cf47..852a18d2797fd 100644
--- a/vllm/platforms/xpu.py
+++ b/vllm/platforms/xpu.py
@@ -66,9 +66,15 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
         # check and update model config
         model_config = vllm_config.model_config
         if model_config.dtype == torch.bfloat16:
-            logger.warning(
-                "bfloat16 is not fully supported on XPU, casting to float16.")
-            model_config.dtype = torch.float16
+            bf16_supported = cls.device_support_bf16()
+            device_name = cls.get_device_name()
+            if not bf16_supported:
+                logger.warning(
+                    "bfloat16 is only supported on Intel Data Center GPU, "
+                    "Intel Arc GPU is not supported yet. Your device is %s,"
+                    "which is not supported. will fallback to float16",
+                    device_name)
+                model_config.dtype = torch.float16
         if not model_config.enforce_eager:
             logger.warning(
                 "CUDA graph is not supported on XPU, fallback to the eager "
@@ -116,3 +122,15 @@ def get_current_memory_usage(cls,
                                  ) -> float:
         torch.xpu.reset_peak_memory_stats(device)
         return torch.xpu.max_memory_allocated(device)
+
+    @classmethod
+    def device_support_bf16(cls) -> bool:
+        device_name = cls.get_device_name().lower()
+        if device_name.count("arc") > 0:
+            return False
+        elif device_name.count("data center gpu") > 0:
+            return True
+        else:
+            logger.warning("Unknown device name %s, always use float16",
+                           device_name)
+            return False

From 769fa21a3b9e9f293f9ea224949827428f61fde2 Mon Sep 17 00:00:00 2001
From: Kunshang Ji
Date: Fri, 24 Jan 2025 16:15:11 +0800
Subject: [PATCH 2/3] add doc

Signed-off-by: Kunshang Ji
---
 docs/source/getting_started/installation/gpu/xpu.inc.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/docs/source/getting_started/installation/gpu/xpu.inc.md b/docs/source/getting_started/installation/gpu/xpu.inc.md
index ad59180b96053..ef02d9a078a1b 100644
--- a/docs/source/getting_started/installation/gpu/xpu.inc.md
+++ b/docs/source/getting_started/installation/gpu/xpu.inc.md
@@ -36,7 +36,7 @@ VLLM_TARGET_DEVICE=xpu python setup.py install
 
 :::{note}
 - FP16 is the default data type in the current XPU backend. The BF16 data
-  type is experimental supported now.
+  type is supported on Intel Data Center GPU, not supported on Intel Arc GPU yet.
 :::
 
 ## Set up using Docker

From 21d70404db158066a91933ac2b549f0b6717e892 Mon Sep 17 00:00:00 2001
From: Kunshang Ji
Date: Fri, 24 Jan 2025 16:16:35 +0800
Subject: [PATCH 3/3] minor

Signed-off-by: Kunshang Ji
---
 vllm/platforms/xpu.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/vllm/platforms/xpu.py b/vllm/platforms/xpu.py
index 852a18d2797fd..039cdd5adc9af 100644
--- a/vllm/platforms/xpu.py
+++ b/vllm/platforms/xpu.py
@@ -67,13 +67,12 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
         model_config = vllm_config.model_config
         if model_config.dtype == torch.bfloat16:
             bf16_supported = cls.device_support_bf16()
-            device_name = cls.get_device_name()
             if not bf16_supported:
                 logger.warning(
                     "bfloat16 is only supported on Intel Data Center GPU, "
                     "Intel Arc GPU is not supported yet. Your device is %s,"
                     "which is not supported. will fallback to float16",
-                    device_name)
+                    cls.get_device_name())
                 model_config.dtype = torch.float16
         if not model_config.enforce_eager:
             logger.warning(
                 "CUDA graph is not supported on XPU, fallback to the eager "
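
For reference, a minimal standalone sketch of the device-name heuristic this series introduces. It is not part of the patch: the real check is the `device_support_bf16` classmethod added to `vllm/platforms/xpu.py` above, which reads the name via `cls.get_device_name()`; here the name is passed in directly, and the example device strings are illustrative assumptions.

```python
def device_support_bf16(device_name: str) -> bool:
    # Mirrors the heuristic in the patch: Intel Data Center GPUs get bf16,
    # Intel Arc and unknown devices fall back to float16.
    name = device_name.lower()
    if name.count("arc") > 0:
        return False
    elif name.count("data center gpu") > 0:
        return True
    else:
        # Unknown device: be conservative, as the patched code does.
        return False

# Illustrative device strings (assumed, not taken from the patch):
print(device_support_bf16("Intel(R) Data Center GPU Max 1100"))  # True
print(device_support_bf16("Intel(R) Arc(TM) A770 Graphics"))     # False
```

When the check returns False, `check_and_update_config` logs a warning and casts `model_config.dtype` to `torch.float16`, which is the fallback behavior described in the updated documentation note.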