[Bug]: I try to use vllm==0.6.5 for GLM4-9b-chat but error "/usr/bin/ld: cannot find -lcuda" #11643

Open
Jimmy-L99 opened this issue Dec 31, 2024 · 8 comments
Labels: bug (Something isn't working)

@Jimmy-L99
Your current environment

The output of `python collect_env.py`
Collecting environment information...
PyTorch version: 2.5.1+cu124
Is debug build: False
CUDA used to build PyTorch: 12.4
ROCM used to build PyTorch: N/A

OS: Ubuntu 20.04 LTS (x86_64)
GCC version: (Ubuntu 9.4.0-1ubuntu1~20.04.2) 9.4.0
Clang version: Could not collect
CMake version: version 3.16.3
Libc version: glibc-2.31

Python version: 3.12.0 | packaged by Anaconda, Inc. | (main, Oct  2 2023, 17:29:18) [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-5.4.0-200-generic-x86_64-with-glibc2.31
Is CUDA available: True
CUDA runtime version: 11.4.48
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: 
GPU 0: NVIDIA A100 80GB PCIe
GPU 1: NVIDIA A100 80GB PCIe

Nvidia driver version: 535.183.01
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture:                       x86_64
CPU op-mode(s):                     32-bit, 64-bit
Byte Order:                         Little Endian
Address sizes:                      46 bits physical, 48 bits virtual
CPU(s):                             64
On-line CPU(s) list:                0-63
Thread(s) per core:                 2
Core(s) per socket:                 16
Socket(s):                          2
NUMA node(s):                       2
Vendor ID:                          GenuineIntel
CPU family:                         6
Model:                              85
Model name:                         Intel(R) Xeon(R) Gold 6226R CPU @ 2.90GHz
Stepping:                           7
CPU MHz:                            1263.722
CPU max MHz:                        3900.0000
CPU min MHz:                        1200.0000
BogoMIPS:                           5800.00
Virtualization:                     VT-x
L1d cache:                          1 MiB
L1i cache:                          1 MiB
L2 cache:                           32 MiB
L3 cache:                           44 MiB
NUMA node0 CPU(s):                  0-15,32-47
NUMA node1 CPU(s):                  16-31,48-63
Vulnerability Gather data sampling: Mitigation; Microcode
Vulnerability Itlb multihit:        KVM: Mitigation: Split huge pages
Vulnerability L1tf:                 Not affected
Vulnerability Mds:                  Not affected
Vulnerability Meltdown:             Not affected
Vulnerability Mmio stale data:      Mitigation; Clear CPU buffers; SMT vulnerable
Vulnerability Retbleed:             Mitigation; Enhanced IBRS
Vulnerability Spec store bypass:    Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1:           Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:           Mitigation; Enhanced / Automatic IBRS; IBPB conditional; RSB filling; PBRSB-eIBRS SW sequence; BHI SW loop, KVM SW loop
Vulnerability Srbds:                Not affected
Vulnerability Tsx async abort:      Mitigation; TSX disabled
Flags:                              fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts acpi mmx fxsr sse sse2 ss ht tm pbe syscall nx pdpe1gb rdtscp lm constant_tsc art arch_perfmon pebs bts rep_good nopl xtopology nonstop_tsc cpuid aperfmperf pni pclmulqdq dtes64 monitor ds_cpl vmx smx est tm2 ssse3 sdbg fma cx16 xtpr pdcm pcid dca sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand lahf_lm abm 3dnowprefetch cpuid_fault epb cat_l3 cdp_l3 invpcid_single intel_ppin ssbd mba ibrs ibpb stibp ibrs_enhanced tpr_shadow vnmi flexpriority ept vpid ept_ad fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid cqm mpx rdt_a avx512f avx512dq rdseed adx smap clflushopt clwb intel_pt avx512cd avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local dtherm ida arat pln pts pku ospke avx512_vnni md_clear flush_l1d arch_capabilities

Versions of relevant libraries:
[pip3] numpy==1.26.4
[pip3] nvidia-cublas-cu12==12.4.5.8
[pip3] nvidia-cuda-cupti-cu12==12.4.127
[pip3] nvidia-cuda-nvrtc-cu12==12.4.127
[pip3] nvidia-cuda-runtime-cu12==12.4.127
[pip3] nvidia-cudnn-cu12==9.1.0.70
[pip3] nvidia-cufft-cu12==11.2.1.3
[pip3] nvidia-curand-cu12==10.3.5.147
[pip3] nvidia-cusolver-cu12==11.6.1.9
[pip3] nvidia-cusparse-cu12==12.3.1.170
[pip3] nvidia-ml-py==12.560.30
[pip3] nvidia-nccl-cu12==2.21.5
[pip3] nvidia-nvjitlink-cu12==12.4.127
[pip3] nvidia-nvtx-cu12==12.4.127
[pip3] pyzmq==26.2.0
[pip3] torch==2.5.1
[pip3] torchvision==0.20.1
[pip3] transformers==4.47.1
[pip3] triton==3.1.0
[conda] numpy                     1.26.4                   pypi_0    pypi
[conda] nvidia-cublas-cu12        12.4.5.8                 pypi_0    pypi
[conda] nvidia-cuda-cupti-cu12    12.4.127                 pypi_0    pypi
[conda] nvidia-cuda-nvrtc-cu12    12.4.127                 pypi_0    pypi
[conda] nvidia-cuda-runtime-cu12  12.4.127                 pypi_0    pypi
[conda] nvidia-cudnn-cu12         9.1.0.70                 pypi_0    pypi
[conda] nvidia-cufft-cu12         11.2.1.3                 pypi_0    pypi
[conda] nvidia-curand-cu12        10.3.5.147               pypi_0    pypi
[conda] nvidia-cusolver-cu12      11.6.1.9                 pypi_0    pypi
[conda] nvidia-cusparse-cu12      12.3.1.170               pypi_0    pypi
[conda] nvidia-ml-py              12.560.30                pypi_0    pypi
[conda] nvidia-nccl-cu12          2.21.5                   pypi_0    pypi
[conda] nvidia-nvjitlink-cu12     12.4.127                 pypi_0    pypi
[conda] nvidia-nvtx-cu12          12.4.127                 pypi_0    pypi
[conda] pyzmq                     26.2.0                   pypi_0    pypi
[conda] torch                     2.5.1                    pypi_0    pypi
[conda] torchvision               0.20.1                   pypi_0    pypi
[conda] transformers              4.46.2                   pypi_0    pypi
[conda] triton                    3.1.0                    pypi_0    pypi
ROCM Version: Could not collect
Neuron SDK Version: N/A
vLLM Version: 0.6.5
vLLM Build Flags:
CUDA Archs: Not Set; ROCm: Disabled; Neuron: Disabled
GPU Topology:
GPU0    GPU1    CPU Affinity    NUMA Affinity   GPU NUMA ID
GPU0     X      NODE    0-15,32-47      0               N/A
GPU1    NODE     X      0-15,32-47      0               N/A

Legend:

  X    = Self
  SYS  = Connection traversing PCIe as well as the SMP interconnect between NUMA nodes (e.g., QPI/UPI)
  NODE = Connection traversing PCIe as well as the interconnect between PCIe Host Bridges within a NUMA node
  PHB  = Connection traversing PCIe as well as a PCIe Host Bridge (typically the CPU)
  PXB  = Connection traversing multiple PCIe bridges (without traversing the PCIe Host Bridge)
  PIX  = Connection traversing at most a single PCIe bridge
  NV#  = Connection traversing a bonded set of # NVLinks

LD_LIBRARY_PATH=/root/anaconda3/envs/glm4-9b-chat-128k-vLLM0_6_5/lib/python3.12/site-packages/cv2/../../lib64:/usr/local/cuda-11.4/targets/x86_64-linux/lib/stubs:/usr/local/cuda-11.4/lib64:/usr/local/cuda-11.4/lib64
CUDA_MODULE_LOADING=LAZY

Model Input Dumps

(glm4-9b-chat-128k-vLLM0_6_5) root@node1:~/ljm/ChatGLM4/GLM-4/api_server_vLLM# python /root/ljm/ChatGLM4/GLM-4/api_server_vLLM/openai_api_server_vLLM_glm4-chat.py
INFO 12-31 10:07:32 config.py:478] This model supports multiple tasks: {'reward', 'embed', 'classify', 'score', 'generate'}. Defaulting to 'generate'.
WARNING 12-31 10:07:32 arg_utils.py:1096] The model has a long context length (65528). This may cause OOM errors during the initial memory profiling phase, or result in low performance due to small KV cache space. Consider setting --max-model-len to a smaller value.
INFO 12-31 10:07:32 llm_engine.py:249] Initializing an LLM engine (v0.6.5) with config: model='/root/ljm/models/glm-4-9b-chat', speculative_config=None, tokenizer='/root/ljm/models/glm-4-9b-chat', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, override_neuron_config=None, tokenizer_revision=None, trust_remote_code=True, dtype=torch.bfloat16, max_seq_len=65528, download_dir=None, load_format=LoadFormat.AUTO, tensor_parallel_size=1, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=False, kv_cache_dtype=auto, quantization_param_path=None, device_config=cuda, decoding_config=DecodingConfig(guided_decoding_backend='xgrammar'), observability_config=ObservabilityConfig(otlp_traces_endpoint=None, collect_model_forward_time=False, collect_model_execute_time=False), seed=0, served_model_name=/root/ljm/models/glm-4-9b-chat, num_scheduler_steps=1, multi_step_stream_outputs=True, enable_prefix_caching=False, chunked_prefill_enabled=False, use_async_output_proc=True, mm_cache_preprocessor=False, mm_processor_kwargs=None, pooler_config=None, compilation_config={"splitting_ops":["vllm.unified_attention","vllm.unified_attention_with_output"],"candidate_compile_sizes":[],"compile_sizes":[],"capture_sizes":[256,248,240,232,224,216,208,200,192,184,176,168,160,152,144,136,128,120,112,104,96,88,80,72,64,56,48,40,32,24,16,8,4,2,1],"max_capture_size":256}, use_cached_outputs=False, 
WARNING 12-31 10:07:33 tokenizer.py:174] Using a slow tokenizer. This might cause a significant slowdown. Consider using a fast tokenizer instead.
INFO 12-31 10:07:34 selector.py:120] Using Flash Attention backend.
INFO 12-31 10:07:34 model_runner.py:1092] Starting to load model /root/ljm/models/glm-4-9b-chat...
Loading safetensors checkpoint shards:   0% Completed | 0/10 [00:00<?, ?it/s]
Loading safetensors checkpoint shards:  10% Completed | 1/10 [00:00<00:04,  1.94it/s]
Loading safetensors checkpoint shards:  20% Completed | 2/10 [00:01<00:04,  1.90it/s]
Loading safetensors checkpoint shards:  30% Completed | 3/10 [00:01<00:03,  1.83it/s]
Loading safetensors checkpoint shards:  40% Completed | 4/10 [00:02<00:03,  1.80it/s]
Loading safetensors checkpoint shards:  50% Completed | 5/10 [00:02<00:02,  1.83it/s]
Loading safetensors checkpoint shards:  60% Completed | 6/10 [00:03<00:02,  1.82it/s]
Loading safetensors checkpoint shards:  70% Completed | 7/10 [00:03<00:01,  1.92it/s]
Loading safetensors checkpoint shards:  80% Completed | 8/10 [00:04<00:01,  1.94it/s]
Loading safetensors checkpoint shards:  90% Completed | 9/10 [00:04<00:00,  1.91it/s]
Loading safetensors checkpoint shards: 100% Completed | 10/10 [00:05<00:00,  1.85it/s]
Loading safetensors checkpoint shards: 100% Completed | 10/10 [00:05<00:00,  1.87it/s]

INFO 12-31 10:07:40 model_runner.py:1097] Loading model weights took 17.5635 GB
INFO 12-31 10:07:40 punica_selector.py:11] Using PunicaWrapperGPU.
/usr/bin/ld: cannot find -lcuda
collect2: error: ld returned 1 exit status
INFO 12-31 10:07:41 model_runner_base.py:120] Writing input of failed execution to /tmp/err_execute_model_input_20241231-100741.pkl...
INFO 12-31 10:07:41 model_runner_base.py:149] Completed writing input of failed execution to /tmp/err_execute_model_input_20241231-100741.pkl.
[rank0]: Traceback (most recent call last):
[rank0]:   File "/root/anaconda3/envs/glm4-9b-chat-128k-vLLM0_6_5/lib/python3.12/site-packages/vllm/worker/model_runner_base.py", line 116, in _wrapper
[rank0]:     return func(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/anaconda3/envs/glm4-9b-chat-128k-vLLM0_6_5/lib/python3.12/site-packages/vllm/worker/model_runner.py", line 1683, in execute_model
[rank0]:     hidden_or_intermediate_states = model_executable(
[rank0]:                                     ^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/anaconda3/envs/glm4-9b-chat-128k-vLLM0_6_5/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/anaconda3/envs/glm4-9b-chat-128k-vLLM0_6_5/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/anaconda3/envs/glm4-9b-chat-128k-vLLM0_6_5/lib/python3.12/site-packages/vllm/model_executor/models/chatglm.py", line 639, in forward
[rank0]:     hidden_states = self.transformer(input_ids, positions, kv_caches,
[rank0]:                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/anaconda3/envs/glm4-9b-chat-128k-vLLM0_6_5/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/anaconda3/envs/glm4-9b-chat-128k-vLLM0_6_5/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/anaconda3/envs/glm4-9b-chat-128k-vLLM0_6_5/lib/python3.12/site-packages/vllm/model_executor/models/chatglm.py", line 595, in forward
[rank0]:     hidden_states = self.encoder(
[rank0]:                     ^^^^^^^^^^^^^
[rank0]:   File "/root/anaconda3/envs/glm4-9b-chat-128k-vLLM0_6_5/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/anaconda3/envs/glm4-9b-chat-128k-vLLM0_6_5/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/anaconda3/envs/glm4-9b-chat-128k-vLLM0_6_5/lib/python3.12/site-packages/vllm/model_executor/models/chatglm.py", line 480, in forward
[rank0]:     hidden_states = layer(
[rank0]:                     ^^^^^^
[rank0]:   File "/root/anaconda3/envs/glm4-9b-chat-128k-vLLM0_6_5/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/anaconda3/envs/glm4-9b-chat-128k-vLLM0_6_5/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/anaconda3/envs/glm4-9b-chat-128k-vLLM0_6_5/lib/python3.12/site-packages/vllm/model_executor/models/chatglm.py", line 408, in forward
[rank0]:     attention_output = self.self_attention(
[rank0]:                        ^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/anaconda3/envs/glm4-9b-chat-128k-vLLM0_6_5/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/anaconda3/envs/glm4-9b-chat-128k-vLLM0_6_5/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/anaconda3/envs/glm4-9b-chat-128k-vLLM0_6_5/lib/python3.12/site-packages/vllm/model_executor/models/chatglm.py", line 300, in forward
[rank0]:     qkv, _ = self.query_key_value(hidden_states)
[rank0]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/anaconda3/envs/glm4-9b-chat-128k-vLLM0_6_5/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/anaconda3/envs/glm4-9b-chat-128k-vLLM0_6_5/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/anaconda3/envs/glm4-9b-chat-128k-vLLM0_6_5/lib/python3.12/site-packages/vllm/lora/layers.py", line 512, in forward
[rank0]:     output_parallel = self.apply(input_, bias)
[rank0]:                       ^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/anaconda3/envs/glm4-9b-chat-128k-vLLM0_6_5/lib/python3.12/site-packages/vllm/lora/layers.py", line 392, in apply
[rank0]:     self.punica_wrapper.add_lora_linear(output, x, self.lora_a_stacked,
[rank0]:   File "/root/anaconda3/envs/glm4-9b-chat-128k-vLLM0_6_5/lib/python3.12/site-packages/vllm/lora/punica_wrapper/punica_gpu.py", line 308, in add_lora_linear
[rank0]:     self.add_shrink(buffer, x, lora_a_stacked, scale, **kwargs)
[rank0]:   File "/root/anaconda3/envs/glm4-9b-chat-128k-vLLM0_6_5/lib/python3.12/site-packages/vllm/lora/punica_wrapper/punica_gpu.py", line 187, in add_shrink
[rank0]:     self._apply_shrink(y[slice_idx], x, lora_a_stacked[slice_idx],
[rank0]:   File "/root/anaconda3/envs/glm4-9b-chat-128k-vLLM0_6_5/lib/python3.12/site-packages/vllm/lora/punica_wrapper/punica_gpu.py", line 160, in _apply_shrink
[rank0]:     shrink_fun(y, x, w_t_all, scale)
[rank0]:   File "/root/anaconda3/envs/glm4-9b-chat-128k-vLLM0_6_5/lib/python3.12/site-packages/vllm/lora/punica_wrapper/punica_gpu.py", line 48, in _shrink_prefill
[rank0]:     sgmv_shrink(
[rank0]:   File "/root/anaconda3/envs/glm4-9b-chat-128k-vLLM0_6_5/lib/python3.12/site-packages/torch/_ops.py", line 1116, in __call__
[rank0]:     return self._op(*args, **(kwargs or {}))
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/anaconda3/envs/glm4-9b-chat-128k-vLLM0_6_5/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
[rank0]:     return func(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/anaconda3/envs/glm4-9b-chat-128k-vLLM0_6_5/lib/python3.12/site-packages/vllm/lora/ops/sgmv_shrink.py", line 169, in _sgmv_shrink
[rank0]:     _sgmv_shrink_kernel[grid](
[rank0]:   File "/root/anaconda3/envs/glm4-9b-chat-128k-vLLM0_6_5/lib/python3.12/site-packages/triton/runtime/jit.py", line 345, in <lambda>
[rank0]:     return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
[rank0]:                                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/anaconda3/envs/glm4-9b-chat-128k-vLLM0_6_5/lib/python3.12/site-packages/triton/runtime/jit.py", line 607, in run
[rank0]:     device = driver.active.get_current_device()
[rank0]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/anaconda3/envs/glm4-9b-chat-128k-vLLM0_6_5/lib/python3.12/site-packages/triton/runtime/driver.py", line 23, in __getattr__
[rank0]:     self._initialize_obj()
[rank0]:   File "/root/anaconda3/envs/glm4-9b-chat-128k-vLLM0_6_5/lib/python3.12/site-packages/triton/runtime/driver.py", line 20, in _initialize_obj
[rank0]:     self._obj = self._init_fn()
[rank0]:                 ^^^^^^^^^^^^^^^
[rank0]:   File "/root/anaconda3/envs/glm4-9b-chat-128k-vLLM0_6_5/lib/python3.12/site-packages/triton/runtime/driver.py", line 9, in _create_driver
[rank0]:     return actives[0]()
[rank0]:            ^^^^^^^^^^^^
[rank0]:   File "/root/anaconda3/envs/glm4-9b-chat-128k-vLLM0_6_5/lib/python3.12/site-packages/triton/backends/nvidia/driver.py", line 371, in __init__
[rank0]:     self.utils = CudaUtils()  # TODO: make static
[rank0]:                  ^^^^^^^^^^^
[rank0]:   File "/root/anaconda3/envs/glm4-9b-chat-128k-vLLM0_6_5/lib/python3.12/site-packages/triton/backends/nvidia/driver.py", line 80, in __init__
[rank0]:     mod = compile_module_from_src(Path(os.path.join(dirname, "driver.c")).read_text(), "cuda_utils")
[rank0]:           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/anaconda3/envs/glm4-9b-chat-128k-vLLM0_6_5/lib/python3.12/site-packages/triton/backends/nvidia/driver.py", line 57, in compile_module_from_src
[rank0]:     so = _build(name, src_path, tmpdir, library_dirs(), include_dir, libraries)
[rank0]:          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/anaconda3/envs/glm4-9b-chat-128k-vLLM0_6_5/lib/python3.12/site-packages/triton/runtime/build.py", line 48, in _build
[rank0]:     ret = subprocess.check_call(cc_cmd)
[rank0]:           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/anaconda3/envs/glm4-9b-chat-128k-vLLM0_6_5/lib/python3.12/subprocess.py", line 413, in check_call
[rank0]:     raise CalledProcessError(retcode, cmd)
[rank0]: subprocess.CalledProcessError: Command '['/usr/bin/gcc', '/tmp/tmpmsoch_qa/main.c', '-O3', '-shared', '-fPIC', '-o', '/tmp/tmpmsoch_qa/cuda_utils.cpython-312-x86_64-linux-gnu.so', '-lcuda', '-L/root/anaconda3/envs/glm4-9b-chat-128k-vLLM0_6_5/lib/python3.12/site-packages/triton/backends/nvidia/lib', '-L/lib/x86_64-linux-gnu', '-I/root/anaconda3/envs/glm4-9b-chat-128k-vLLM0_6_5/lib/python3.12/site-packages/triton/backends/nvidia/include', '-I/tmp/tmpmsoch_qa', '-I/root/anaconda3/envs/glm4-9b-chat-128k-vLLM0_6_5/include/python3.12']' returned non-zero exit status 1.

[rank0]: During handling of the above exception, another exception occurred:

[rank0]: Traceback (most recent call last):
[rank0]:   File "/root/ljm/ChatGLM4/GLM-4/api_server_vLLM/openai_api_server_vLLM_glm4-chat.py", line 691, in <module>
[rank0]:     engine = AsyncLLMEngine.from_engine_args(engine_args)
[rank0]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/anaconda3/envs/glm4-9b-chat-128k-vLLM0_6_5/lib/python3.12/site-packages/vllm/engine/async_llm_engine.py", line 707, in from_engine_args
[rank0]:     engine = cls(
[rank0]:              ^^^^
[rank0]:   File "/root/anaconda3/envs/glm4-9b-chat-128k-vLLM0_6_5/lib/python3.12/site-packages/vllm/engine/async_llm_engine.py", line 594, in __init__
[rank0]:     self.engine = self._engine_class(*args, **kwargs)
[rank0]:                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/anaconda3/envs/glm4-9b-chat-128k-vLLM0_6_5/lib/python3.12/site-packages/vllm/engine/async_llm_engine.py", line 267, in __init__
[rank0]:     super().__init__(*args, **kwargs)
[rank0]:   File "/root/anaconda3/envs/glm4-9b-chat-128k-vLLM0_6_5/lib/python3.12/site-packages/vllm/engine/llm_engine.py", line 291, in __init__
[rank0]:     self._initialize_kv_caches()
[rank0]:   File "/root/anaconda3/envs/glm4-9b-chat-128k-vLLM0_6_5/lib/python3.12/site-packages/vllm/engine/llm_engine.py", line 431, in _initialize_kv_caches
[rank0]:     self.model_executor.determine_num_available_blocks())
[rank0]:     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/anaconda3/envs/glm4-9b-chat-128k-vLLM0_6_5/lib/python3.12/site-packages/vllm/executor/gpu_executor.py", line 68, in determine_num_available_blocks
[rank0]:     return self.driver_worker.determine_num_available_blocks()
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/anaconda3/envs/glm4-9b-chat-128k-vLLM0_6_5/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
[rank0]:     return func(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/anaconda3/envs/glm4-9b-chat-128k-vLLM0_6_5/lib/python3.12/site-packages/vllm/worker/worker.py", line 202, in determine_num_available_blocks
[rank0]:     self.model_runner.profile_run()
[rank0]:   File "/root/anaconda3/envs/glm4-9b-chat-128k-vLLM0_6_5/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
[rank0]:     return func(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/anaconda3/envs/glm4-9b-chat-128k-vLLM0_6_5/lib/python3.12/site-packages/vllm/worker/model_runner.py", line 1329, in profile_run
[rank0]:     self.execute_model(model_input, kv_caches, intermediate_tensors)
[rank0]:   File "/root/anaconda3/envs/glm4-9b-chat-128k-vLLM0_6_5/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
[rank0]:     return func(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/anaconda3/envs/glm4-9b-chat-128k-vLLM0_6_5/lib/python3.12/site-packages/vllm/worker/model_runner_base.py", line 152, in _wrapper
[rank0]:     raise type(err)(
[rank0]:           ^^^^^^^^^^
[rank0]: TypeError: CalledProcessError.__init__() missing 1 required positional argument: 'cmd'
[rank0]:[W1231 10:07:41.345588356 ProcessGroupNCCL.cpp:1250] Warning: WARNING: process group has NOT been destroyed before we destruct ProcessGroupNCCL. On normal program exit, the application should call destroy_process_group to ensure that any pending NCCL operations have finished in this process. In rare cases this process can exit before this point and block the progress of another member of the process group. This constraint has always been present,  but this warning has only been added since PyTorch 2.4 (function operator())

🐛 Describe the bug

I'm trying to use glm4-9b-chat with vllm==0.6.5, but it fails with /usr/bin/ld: cannot find -lcuda. With vllm==0.5.5 it works fine.
Does anybody know what's going on here and what the problem is?
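
From the traceback, the failing step is Triton compiling its cuda_utils helper at runtime with gcc ... -lcuda, so the compile-time linker cannot find an unversioned libcuda.so on its search path. A minimal diagnostic sketch, assuming a CUDA toolkit under /usr/local/cuda (paths are examples; adjust them to your install):

# locate driver libraries and stubs on this machine
find / -name "libcuda.so*" 2>/dev/null
# reproduce the link error outside of vLLM/Triton
echo 'int main(){return 0;}' > /tmp/t.c && gcc /tmp/t.c -lcuda -o /tmp/t
# one common workaround: expose the toolkit's stub libcuda.so to the compile-time linker
export LIBRARY_PATH=/usr/local/cuda/lib64/stubs:$LIBRARY_PATH
# if your Triton version supports it, TRITON_LIBCUDA_PATH can also point the build at a directory containing libcuda.so
# (check your installed triton/backends/nvidia/driver.py before relying on it)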

Before submitting a new issue...

  • Make sure you already searched for relevant issues, and asked the chatbot living at the bottom right corner of the documentation page, which can answer lots of frequently asked questions.
Jimmy-L99 added the bug (Something isn't working) label on Dec 31, 2024
@jeejeelee
Collaborator

Maybe the cause is that your CUDA runtime is old:

CUDA runtime version: 11.4.48

@Jimmy-L99
Author

Maybe the cause is that your CUDA runtime is old:

CUDA runtime version: 11.4.48

Thanks. By the way, could you tell me which version of the CUDA runtime is required?

@jeejeelee
Collaborator

I am using version 12.4
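
For comparison with your own machine, the toolkit and driver versions can be checked directly (a minimal sketch; nvcc may not be on PATH in every environment):

nvcc --version                                        # CUDA toolkit version (what collect_env reports as "CUDA runtime version")
nvidia-smi                                            # driver version and the highest CUDA version the driver supports
python -c "import torch; print(torch.version.cuda)"   # CUDA version the installed PyTorch wheel was built against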

@gargnipungarg

Same issue with 12.4 for me

@balachandarsv

balachandarsv commented Jan 22, 2025

The issue still exists in the latest nightly builds. Is someone looking into this? Using CUDA 12.4.

@stefanobranco

Same here. I updated our DGX because I figured it was a CUDA issue as well, but that doesn't seem to help (though I was already on 12.4 before updating). I'm using the Docker containers, and it also happens with the newer 0.6.6.

@stefanobranco

stefanobranco commented Jan 31, 2025

I've looked into this a bit more. We're using the nvidia-gpu-operator to deploy vLLM via Helm charts. I'm no expert in this area at all, but maybe this helps someone smarter than me. I think there are a few separate issues going on. One of them should be fixed in the next version (see #12505). For that issue the symlink fix should work, but if you've been having the problem before 0.7.0 already, it's something else.

I've connected to the containers manually, and running ld -lcuda --verbose shows that even on 0.6.4.post1 (the last container that works for me) I get /usr/bin/ld: cannot find -lcuda; it just hadn't caused any problems. Looking at the containers a bit more closely, I believe this could instead be connected to the nvidia-gpu-operator:

0.6.4.post1:

root@dgx-inference:/vllm-workspace# find /usr -name "libcuda.so*"
/usr/local/cuda-12.4/compat/libcuda.so.1
/usr/local/cuda-12.4/compat/libcuda.so.550.54.15
/usr/local/cuda-12.4/compat/libcuda.so
/usr/lib/x86_64-linux-gnu/libcuda.so.1
/usr/lib/x86_64-linux-gnu/libcuda.so.550.127.08

0.7.0:

root@dgx-inference:/vllm-workspace# find /usr -name "libcuda.so*"
/usr/local/cuda-12.1/compat/libcuda.so.1
/usr/local/cuda-12.1/compat/libcuda.so
/usr/local/cuda-12.1/compat/libcuda.so.530.30.02
/usr/lib/x86_64-linux-gnu/libcuda.so.1
/usr/lib/x86_64-linux-gnu/libcuda.so.550.127.08

I'm not sure why there's a downgrade from 12.4 to 12.1 here; these run on the same host system, and the only difference is the vLLM container.

Maybe more interesting, though, is the missing libcuda.so in /usr/lib. Adding a link there as described in #12505 seems to fix the issue. On the host system everything looks correct:

/usr/local/cuda-12.8/targets/x86_64-linux/lib/stubs/libcuda.so
/usr/lib/x86_64-linux-gnu/libcuda.so.1
/usr/lib/x86_64-linux-gnu/libcuda.so
/usr/lib/x86_64-linux-gnu/libcuda.so.550.127.08

This makes me think it's more likely an issue with the nvidia-gpu-operator than with vLLM itself, but once again I'm not an expert in this; it's just what I noticed.
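
For reference, a sketch of the symlink workaround mentioned above, run inside the affected container (the paths come from the listings above, so verify them in your own image first):

# the driver library is present only as libcuda.so.1; the build-time linker looks for the unversioned name
ln -s /usr/lib/x86_64-linux-gnu/libcuda.so.1 /usr/lib/x86_64-linux-gnu/libcuda.so
ldconfig   # refresh the dynamic linker cache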

@balachandarsv

#12312 (comment)

I had similar issues; maybe this could help.
