From 042c3419fad1a89c32a27abe8089af6de960bfce Mon Sep 17 00:00:00 2001
From: Lu Fang <30275821+houseroad@users.noreply.github.com>
Date: Wed, 12 Feb 2025 09:06:13 -0800
Subject: [PATCH] Introduce VLLM_CUDART_SO_PATH to allow users specify the .so path (#12998)

Signed-off-by: Lu Fang
---
 .../device_communicators/cuda_wrapper.py      | 32 ++++++++++++++++++-
 vllm/envs.py                                  |  6 ++++
 2 files changed, 37 insertions(+), 1 deletion(-)

diff --git a/vllm/distributed/device_communicators/cuda_wrapper.py b/vllm/distributed/device_communicators/cuda_wrapper.py
index 010caf7ebac97..bc2cfbf321875 100644
--- a/vllm/distributed/device_communicators/cuda_wrapper.py
+++ b/vllm/distributed/device_communicators/cuda_wrapper.py
@@ -5,12 +5,14 @@
 """
 
 import ctypes
+import glob
 from dataclasses import dataclass
 from typing import Any, Dict, List, Optional
 
 # this line makes it possible to directly load `libcudart.so` using `ctypes`
 import torch  # noqa
 
+import vllm.envs as envs
 from vllm.logger import init_logger
 
 logger = init_logger(__name__)
@@ -60,6 +62,29 @@ def find_loaded_library(lib_name) -> Optional[str]:
     return path
 
 
+def get_cudart_lib_path_from_env() -> Optional[str]:
+    """
+    Fallback for when find_loaded_library() fails: return the first glob match
+    of env var VLLM_CUDART_SO_PATH, or None if it is unset or matches nothing.
+    """
+    cudart_so_env = envs.VLLM_CUDART_SO_PATH
+    if cudart_so_env is not None:
+        cudart_paths = [
+            cudart_so_env,
+        ]
+        for path in cudart_paths:
+            file_paths = glob.glob(path)
+            if file_paths:
+                logger.info(
+                    "Found cudart library at %s through env var "
+                    "VLLM_CUDART_SO_PATH=%s",
+                    file_paths[0],
+                    cudart_so_env,
+                )
+                return file_paths[0]
+    return None
+
+
 class CudaRTLibrary:
     exported_functions = [
         # ​cudaError_t cudaSetDevice ( int device )
@@ -105,8 +130,13 @@ class CudaRTLibrary:
     def __init__(self, so_file: Optional[str] = None):
         if so_file is None:
             so_file = find_loaded_library("libcudart")
+            if so_file is None:
+                so_file = get_cudart_lib_path_from_env()
             assert so_file is not None, \
-                "libcudart is not loaded in the current process"
+                (
+                    "libcudart is not loaded in the current process, "
+                    "try setting VLLM_CUDART_SO_PATH"
+                )
         if so_file not in CudaRTLibrary.path_to_library_cache:
             lib = ctypes.CDLL(so_file)
             CudaRTLibrary.path_to_library_cache[so_file] = lib
diff --git a/vllm/envs.py b/vllm/envs.py
index 745b068b7a458..d99c794e69e6c 100644
--- a/vllm/envs.py
+++ b/vllm/envs.py
@@ -87,6 +87,7 @@
     VLLM_ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON: bool = False
     VLLM_RAY_PER_WORKER_GPUS: float = 1.0
     VLLM_RAY_BUNDLE_INDICES: str = ""
+    VLLM_CUDART_SO_PATH: Optional[str] = None
 
 
 def get_default_cache_root():
@@ -572,6 +573,11 @@ def maybe_convert_int(value: Optional[str]) -> Optional[int]:
     # models the alignment is already naturally aligned to 256 bytes.
     "VLLM_CUDA_MEM_ALIGN_KV_CACHE":
     lambda: bool(int(os.getenv("VLLM_CUDA_MEM_ALIGN_KV_CACHE", "1"))),
+
+    # On some systems, find_loaded_library() may not work, so we allow users to
+    # specify the path through the environment variable VLLM_CUDART_SO_PATH.
+    "VLLM_CUDART_SO_PATH":
+    lambda: os.getenv("VLLM_CUDART_SO_PATH", None),
 }
 
 # end-env-vars-definition