
Bug: Speculative Decoding "Segmentation fault (core dumped)" #10176

Open
AbdullahMPrograms opened this issue Nov 4, 2024 · 11 comments · Fixed by #10192 · May be fixed by #10185
Labels
bug Something isn't working low severity Used to report low severity bugs in llama.cpp (e.g. cosmetic issues, non critical UI glitches)

Comments

@AbdullahMPrograms

What happened?

Hey all, I wanted to report a segmentation fault issue with llama-speculative. I have never once gotten this executable to work, and I don't believe the problem is my command, since I have also tried copy-pasting the speculative example commands.

Name and Version

ggml_cuda_init: GGML_CUDA_FORCE_MMQ: no
ggml_cuda_init: GGML_CUDA_FORCE_CUBLAS: no
ggml_cuda_init: found 3 CUDA devices:
Device 0: Tesla P40, compute capability 6.1, VMM: yes
Device 1: Tesla P40, compute capability 6.1, VMM: yes
Device 2: Tesla P40, compute capability 6.1, VMM: yes
version: 4031 (d5a409e)
built with cc (Ubuntu 13.2.0-23ubuntu4) 13.2.0 for x86_64-linux-gnu

What operating system are you seeing the problem on?

No response

Relevant log output

./LLM/llama.cpp/llama-speculative \
-m /home/ultimis/LLM/Models/mradermacher/Meta-Llama-3.1-70B-Instruct-i1-GGUF/Meta-Llama-3.1-70B-Instruct.i1-Q4_K_M.gguf \
-md /home/ultimis/LLM/Models/hugging-quants/Llama-3.2-1B-Instruct-Q8_0-GGUF/llama-3.2-1b-instruct-q8_0.gguf \
-p "// Quick-sort implementation in C (4 spaces indentation + detailed comments) and sample usage" \
-c 8000 -ngl 99 -ngld 30 --split-mode row --draft 16
ggml_cuda_init: GGML_CUDA_FORCE_MMQ:    no
ggml_cuda_init: GGML_CUDA_FORCE_CUBLAS: no
ggml_cuda_init: found 3 CUDA devices:
  Device 0: Tesla P40, compute capability 6.1, VMM: yes
  Device 1: Tesla P40, compute capability 6.1, VMM: yes
  Device 2: Tesla P40, compute capability 6.1, VMM: yes
build: 4031 (d5a409e5) with cc (Ubuntu 13.2.0-23ubuntu4) 13.2.0 for x86_64-linux-gnu
llama_load_model_from_file: using device CUDA0 (Tesla P40) - 24286 MiB free
llama_load_model_from_file: using device CUDA1 (Tesla P40) - 24290 MiB free
llama_load_model_from_file: using device CUDA2 (Tesla P40) - 24290 MiB free
llama_model_loader: loaded meta data with 40 key-value pairs and 724 tensors from /home/ultimis/LLM/Models/mradermacher/Meta-Llama-3.1-70B-Instruct-i1-GGUF/Meta-Llama-3.1-70B-Instruct.i1-Q4_K_M.gguf (version GGUF V3 (latest))
llama_model_loader: Dumping metadata keys/values. Note: KV overrides do not apply in this output.
llama_model_loader: - kv   0:                       general.architecture str              = llama
llama_model_loader: - kv   1:                               general.type str              = model
llama_model_loader: - kv   2:                               general.name str              = Meta Llama 3.1 70B Instruct
llama_model_loader: - kv   3:                           general.finetune str              = Instruct
llama_model_loader: - kv   4:                           general.basename str              = Meta-Llama-3.1
llama_model_loader: - kv   5:                         general.size_label str              = 70B
llama_model_loader: - kv   6:                            general.license str              = llama3.1
llama_model_loader: - kv   7:                               general.tags arr[str,6]       = ["facebook", "meta", "pytorch", "llam...
llama_model_loader: - kv   8:                          general.languages arr[str,8]       = ["en", "de", "fr", "it", "pt", "hi", ...
llama_model_loader: - kv   9:                          llama.block_count u32              = 80
llama_model_loader: - kv  10:                       llama.context_length u32              = 131072
llama_model_loader: - kv  11:                     llama.embedding_length u32              = 8192
llama_model_loader: - kv  12:                  llama.feed_forward_length u32              = 28672
llama_model_loader: - kv  13:                 llama.attention.head_count u32              = 64
llama_model_loader: - kv  14:              llama.attention.head_count_kv u32              = 8
llama_model_loader: - kv  15:                       llama.rope.freq_base f32              = 500000.000000
llama_model_loader: - kv  16:     llama.attention.layer_norm_rms_epsilon f32              = 0.000010
llama_model_loader: - kv  17:                          general.file_type u32              = 15
llama_model_loader: - kv  18:                           llama.vocab_size u32              = 128256
llama_model_loader: - kv  19:                 llama.rope.dimension_count u32              = 128
llama_model_loader: - kv  20:                       tokenizer.ggml.model str              = gpt2
llama_model_loader: - kv  21:                         tokenizer.ggml.pre str              = llama-bpe
llama_model_loader: - kv  22:                      tokenizer.ggml.tokens arr[str,128256]  = ["!", "\"", "#", "$", "%", "&", "'", ...
llama_model_loader: - kv  23:                  tokenizer.ggml.token_type arr[i32,128256]  = [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...
llama_model_loader: - kv  24:                      tokenizer.ggml.merges arr[str,280147]  = ["Ġ Ġ", "Ġ ĠĠĠ", "ĠĠ ĠĠ", "...
llama_model_loader: - kv  25:                tokenizer.ggml.bos_token_id u32              = 128000
llama_model_loader: - kv  26:                tokenizer.ggml.eos_token_id u32              = 128009
llama_model_loader: - kv  27:                    tokenizer.chat_template str              = {% set loop_messages = messages %}{% ...
llama_model_loader: - kv  28:               general.quantization_version u32              = 2
llama_model_loader: - kv  29:                                general.url str              = https://huggingface.co/mradermacher/M...
llama_model_loader: - kv  30:              mradermacher.quantize_version str              = 2
llama_model_loader: - kv  31:                  mradermacher.quantized_by str              = mradermacher
llama_model_loader: - kv  32:                  mradermacher.quantized_at str              = 2024-07-29T10:58:40+02:00
llama_model_loader: - kv  33:                  mradermacher.quantized_on str              = db1
llama_model_loader: - kv  34:                         general.source.url str              = https://huggingface.co/meta-llama/Met...
llama_model_loader: - kv  35:                  mradermacher.convert_type str              = hf
llama_model_loader: - kv  36:                      quantize.imatrix.file str              = Meta-Llama-3.1-70B-Instruct-i1-GGUF/i...
llama_model_loader: - kv  37:                   quantize.imatrix.dataset str              = imatrix-training-full-3
llama_model_loader: - kv  38:             quantize.imatrix.entries_count i32              = 560
llama_model_loader: - kv  39:              quantize.imatrix.chunks_count i32              = 314
llama_model_loader: - type  f32:  162 tensors
llama_model_loader: - type q4_K:  441 tensors
llama_model_loader: - type q5_K:   40 tensors
llama_model_loader: - type q6_K:   81 tensors
llm_load_vocab: special tokens cache size = 256
llm_load_vocab: token to piece cache size = 0.7999 MB
llm_load_print_meta: format           = GGUF V3 (latest)
llm_load_print_meta: arch             = llama
llm_load_print_meta: vocab type       = BPE
llm_load_print_meta: n_vocab          = 128256
llm_load_print_meta: n_merges         = 280147
llm_load_print_meta: vocab_only       = 0
llm_load_print_meta: n_ctx_train      = 131072
llm_load_print_meta: n_embd           = 8192
llm_load_print_meta: n_layer          = 80
llm_load_print_meta: n_head           = 64
llm_load_print_meta: n_head_kv        = 8
llm_load_print_meta: n_rot            = 128
llm_load_print_meta: n_swa            = 0
llm_load_print_meta: n_embd_head_k    = 128
llm_load_print_meta: n_embd_head_v    = 128
llm_load_print_meta: n_gqa            = 8
llm_load_print_meta: n_embd_k_gqa     = 1024
llm_load_print_meta: n_embd_v_gqa     = 1024
llm_load_print_meta: f_norm_eps       = 0.0e+00
llm_load_print_meta: f_norm_rms_eps   = 1.0e-05
llm_load_print_meta: f_clamp_kqv      = 0.0e+00
llm_load_print_meta: f_max_alibi_bias = 0.0e+00
llm_load_print_meta: f_logit_scale    = 0.0e+00
llm_load_print_meta: n_ff             = 28672
llm_load_print_meta: n_expert         = 0
llm_load_print_meta: n_expert_used    = 0
llm_load_print_meta: causal attn      = 1
llm_load_print_meta: pooling type     = 0
llm_load_print_meta: rope type        = 0
llm_load_print_meta: rope scaling     = linear
llm_load_print_meta: freq_base_train  = 500000.0
llm_load_print_meta: freq_scale_train = 1
llm_load_print_meta: n_ctx_orig_yarn  = 131072
llm_load_print_meta: rope_finetuned   = unknown
llm_load_print_meta: ssm_d_conv       = 0
llm_load_print_meta: ssm_d_inner      = 0
llm_load_print_meta: ssm_d_state      = 0
llm_load_print_meta: ssm_dt_rank      = 0
llm_load_print_meta: ssm_dt_b_c_rms   = 0
llm_load_print_meta: model type       = 70B
llm_load_print_meta: model ftype      = Q4_K - Medium
llm_load_print_meta: model params     = 70.55 B
llm_load_print_meta: model size       = 39.59 GiB (4.82 BPW)
llm_load_print_meta: general.name     = Meta Llama 3.1 70B Instruct
llm_load_print_meta: BOS token        = 128000 '<|begin_of_text|>'
llm_load_print_meta: EOS token        = 128009 '<|eot_id|>'
llm_load_print_meta: EOT token        = 128009 '<|eot_id|>'
llm_load_print_meta: EOM token        = 128008 '<|eom_id|>'
llm_load_print_meta: LF token         = 128 'Ä'
llm_load_print_meta: EOG token        = 128008 '<|eom_id|>'
llm_load_print_meta: EOG token        = 128009 '<|eot_id|>'
llm_load_print_meta: max token length = 256
llm_load_tensors: offloading 80 repeating layers to GPU
llm_load_tensors: offloading output layer to GPU
llm_load_tensors: offloaded 81/81 layers to GPU
llm_load_tensors: CPU_Mapped model buffer size =   563.62 MiB
llm_load_tensors:      CUDA0 model buffer size =     1.69 MiB
llm_load_tensors:      CUDA1 model buffer size =     1.69 MiB
llm_load_tensors:      CUDA2 model buffer size =     1.66 MiB
llm_load_tensors: CUDA0_Split model buffer size = 13302.19 MiB
llm_load_tensors: CUDA1_Split model buffer size = 12949.31 MiB
llm_load_tensors: CUDA2_Split model buffer size = 13722.95 MiB
...................................................................................................
llama_new_context_with_model: n_seq_max     = 1
llama_new_context_with_model: n_ctx         = 8000
llama_new_context_with_model: n_ctx_per_seq = 8000
llama_new_context_with_model: n_batch       = 2048
llama_new_context_with_model: n_ubatch      = 512
llama_new_context_with_model: flash_attn    = 0
llama_new_context_with_model: freq_base     = 500000.0
llama_new_context_with_model: freq_scale    = 1
llama_new_context_with_model: n_ctx_per_seq (8000) < n_ctx_train (131072) -- the full capacity of the model will not be utilized
llama_kv_cache_init:      CUDA0 KV buffer size =   843.75 MiB
llama_kv_cache_init:      CUDA1 KV buffer size =   843.75 MiB
llama_kv_cache_init:      CUDA2 KV buffer size =   812.50 MiB
llama_new_context_with_model: KV self size  = 2500.00 MiB, K (f16): 1250.00 MiB, V (f16): 1250.00 MiB
llama_new_context_with_model:  CUDA_Host  output buffer size =     0.49 MiB
llama_new_context_with_model:      CUDA0 compute buffer size =  1079.63 MiB
llama_new_context_with_model:      CUDA1 compute buffer size =  1079.63 MiB
llama_new_context_with_model:      CUDA2 compute buffer size =  1079.63 MiB
llama_new_context_with_model:  CUDA_Host compute buffer size =    31.63 MiB
llama_new_context_with_model: graph nodes  = 2566
llama_new_context_with_model: graph splits = 4
common_init_from_params: warming up the model with an empty run - please wait ... (--no-warmup to disable)
llama_load_model_from_file: using device CUDA0 (Tesla P40) - 8856 MiB free
llama_load_model_from_file: using device CUDA1 (Tesla P40) - 8460 MiB free
llama_load_model_from_file: using device CUDA2 (Tesla P40) - 8090 MiB free
llama_model_loader: loaded meta data with 30 key-value pairs and 147 tensors from /home/ultimis/LLM/Models/hugging-quants/Llama-3.2-1B-Instruct-Q8_0-GGUF/llama-3.2-1b-instruct-q8_0.gguf (version GGUF V3 (latest))
llama_model_loader: Dumping metadata keys/values. Note: KV overrides do not apply in this output.
llama_model_loader: - kv   0:                       general.architecture str              = llama
llama_model_loader: - kv   1:                               general.type str              = model
llama_model_loader: - kv   2:                               general.name str              = Llama 3.2 1B Instruct
llama_model_loader: - kv   3:                           general.finetune str              = Instruct
llama_model_loader: - kv   4:                           general.basename str              = Llama-3.2
llama_model_loader: - kv   5:                         general.size_label str              = 1B
llama_model_loader: - kv   6:                               general.tags arr[str,6]       = ["facebook", "meta", "pytorch", "llam...
llama_model_loader: - kv   7:                          general.languages arr[str,8]       = ["en", "de", "fr", "it", "pt", "hi", ...
llama_model_loader: - kv   8:                          llama.block_count u32              = 16
llama_model_loader: - kv   9:                       llama.context_length u32              = 131072
llama_model_loader: - kv  10:                     llama.embedding_length u32              = 2048
llama_model_loader: - kv  11:                  llama.feed_forward_length u32              = 8192
llama_model_loader: - kv  12:                 llama.attention.head_count u32              = 32
llama_model_loader: - kv  13:              llama.attention.head_count_kv u32              = 8
llama_model_loader: - kv  14:                       llama.rope.freq_base f32              = 500000.000000
llama_model_loader: - kv  15:     llama.attention.layer_norm_rms_epsilon f32              = 0.000010
llama_model_loader: - kv  16:                 llama.attention.key_length u32              = 64
llama_model_loader: - kv  17:               llama.attention.value_length u32              = 64
llama_model_loader: - kv  18:                          general.file_type u32              = 7
llama_model_loader: - kv  19:                           llama.vocab_size u32              = 128256
llama_model_loader: - kv  20:                 llama.rope.dimension_count u32              = 64
llama_model_loader: - kv  21:                       tokenizer.ggml.model str              = gpt2
llama_model_loader: - kv  22:                         tokenizer.ggml.pre str              = llama-bpe
llama_model_loader: - kv  23:                      tokenizer.ggml.tokens arr[str,128256]  = ["!", "\"", "#", "$", "%", "&", "'", ...
llama_model_loader: - kv  24:                  tokenizer.ggml.token_type arr[i32,128256]  = [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...
llama_model_loader: - kv  25:                      tokenizer.ggml.merges arr[str,280147]  = ["Ġ Ġ", "Ġ ĠĠĠ", "ĠĠ ĠĠ", "...
llama_model_loader: - kv  26:                tokenizer.ggml.bos_token_id u32              = 128000
llama_model_loader: - kv  27:                tokenizer.ggml.eos_token_id u32              = 128009
llama_model_loader: - kv  28:                    tokenizer.chat_template str              = {% set loop_messages = messages %}{% ...
llama_model_loader: - kv  29:               general.quantization_version u32              = 2
llama_model_loader: - type  f32:   34 tensors
llama_model_loader: - type q8_0:  113 tensors
llm_load_vocab: special tokens cache size = 256
llm_load_vocab: token to piece cache size = 0.7999 MB
llm_load_print_meta: format           = GGUF V3 (latest)
llm_load_print_meta: arch             = llama
llm_load_print_meta: vocab type       = BPE
llm_load_print_meta: n_vocab          = 128256
llm_load_print_meta: n_merges         = 280147
llm_load_print_meta: vocab_only       = 0
llm_load_print_meta: n_ctx_train      = 131072
llm_load_print_meta: n_embd           = 2048
llm_load_print_meta: n_layer          = 16
llm_load_print_meta: n_head           = 32
llm_load_print_meta: n_head_kv        = 8
llm_load_print_meta: n_rot            = 64
llm_load_print_meta: n_swa            = 0
llm_load_print_meta: n_embd_head_k    = 64
llm_load_print_meta: n_embd_head_v    = 64
llm_load_print_meta: n_gqa            = 4
llm_load_print_meta: n_embd_k_gqa     = 512
llm_load_print_meta: n_embd_v_gqa     = 512
llm_load_print_meta: f_norm_eps       = 0.0e+00
llm_load_print_meta: f_norm_rms_eps   = 1.0e-05
llm_load_print_meta: f_clamp_kqv      = 0.0e+00
llm_load_print_meta: f_max_alibi_bias = 0.0e+00
llm_load_print_meta: f_logit_scale    = 0.0e+00
llm_load_print_meta: n_ff             = 8192
llm_load_print_meta: n_expert         = 0
llm_load_print_meta: n_expert_used    = 0
llm_load_print_meta: causal attn      = 1
llm_load_print_meta: pooling type     = 0
llm_load_print_meta: rope type        = 0
llm_load_print_meta: rope scaling     = linear
llm_load_print_meta: freq_base_train  = 500000.0
llm_load_print_meta: freq_scale_train = 1
llm_load_print_meta: n_ctx_orig_yarn  = 131072
llm_load_print_meta: rope_finetuned   = unknown
llm_load_print_meta: ssm_d_conv       = 0
llm_load_print_meta: ssm_d_inner      = 0
llm_load_print_meta: ssm_d_state      = 0
llm_load_print_meta: ssm_dt_rank      = 0
llm_load_print_meta: ssm_dt_b_c_rms   = 0
llm_load_print_meta: model type       = 1B
llm_load_print_meta: model ftype      = Q8_0
llm_load_print_meta: model params     = 1.24 B
llm_load_print_meta: model size       = 1.22 GiB (8.50 BPW)
llm_load_print_meta: general.name     = Llama 3.2 1B Instruct
llm_load_print_meta: BOS token        = 128000 '<|begin_of_text|>'
llm_load_print_meta: EOS token        = 128009 '<|eot_id|>'
llm_load_print_meta: EOT token        = 128009 '<|eot_id|>'
llm_load_print_meta: EOM token        = 128008 '<|eom_id|>'
llm_load_print_meta: LF token         = 128 'Ä'
llm_load_print_meta: EOG token        = 128008 '<|eom_id|>'
llm_load_print_meta: EOG token        = 128009 '<|eot_id|>'
llm_load_print_meta: max token length = 256
llm_load_tensors: offloading 16 repeating layers to GPU
llm_load_tensors: offloading output layer to GPU
llm_load_tensors: offloaded 17/17 layers to GPU
llm_load_tensors: CPU_Mapped model buffer size =   266.16 MiB
llm_load_tensors:      CUDA0 model buffer size =     0.09 MiB
llm_load_tensors:      CUDA1 model buffer size =     0.09 MiB
llm_load_tensors:      CUDA2 model buffer size =     0.07 MiB
llm_load_tensors: CUDA0_Split model buffer size =   369.75 MiB
llm_load_tensors: CUDA1_Split model buffer size =   369.75 MiB
llm_load_tensors: CUDA2_Split model buffer size =   512.66 MiB
..............................................................
llama_new_context_with_model: n_seq_max     = 1
llama_new_context_with_model: n_ctx         = 8000
llama_new_context_with_model: n_ctx_per_seq = 8000
llama_new_context_with_model: n_batch       = 2048
llama_new_context_with_model: n_ubatch      = 512
llama_new_context_with_model: flash_attn    = 0
llama_new_context_with_model: freq_base     = 500000.0
llama_new_context_with_model: freq_scale    = 1
llama_new_context_with_model: n_ctx_per_seq (8000) < n_ctx_train (131072) -- the full capacity of the model will not be utilized
llama_kv_cache_init:      CUDA0 KV buffer size =    93.75 MiB
llama_kv_cache_init:      CUDA1 KV buffer size =    93.75 MiB
llama_kv_cache_init:      CUDA2 KV buffer size =    62.50 MiB
llama_new_context_with_model: KV self size  =  250.00 MiB, K (f16):  125.00 MiB, V (f16):  125.00 MiB
llama_new_context_with_model:  CUDA_Host  output buffer size =     0.49 MiB
llama_new_context_with_model:      CUDA0 compute buffer size =   531.63 MiB
llama_new_context_with_model:      CUDA1 compute buffer size =   531.63 MiB
llama_new_context_with_model:      CUDA2 compute buffer size =   531.63 MiB
llama_new_context_with_model:  CUDA_Host compute buffer size =    19.63 MiB
llama_new_context_with_model: graph nodes  = 518
llama_new_context_with_model: graph splits = 4
common_init_from_params: warming up the model with an empty run - please wait ... (--no-warmup to disable)


<|begin_of_text|>// Quick-sort implementation in C (4 spaces indentation + detailed comments) and sample usage.

Segmentation fault (core dumped)
@AbdullahMPrograms AbdullahMPrograms added bug-unconfirmed critical severity Used to report critical severity bugs in llama.cpp (e.g. Crashing, Corrupted, Dataloss) labels Nov 4, 2024
@slaren
Collaborator

slaren commented Nov 4, 2024

Looks like a crash in the DRY sampler when it is cloned, due to the model being NULL. Removing the DRY sampler from the sampler sequence with --sampling-seq should allow you to use the speculative example until this is fixed. cc @wwoodsTM

=================================================================
==2432259==ERROR: AddressSanitizer: SEGV on unknown address 0x000000000038 (pc 0x5609ca4f6e80 bp 0x7fff9a863600 sp 0x7fff9a8635f0 T0)
==2432259==The signal is caused by a READ memory access.
==2432259==Hint: address points to the zero page.
    #0 0x5609ca4f6e80 in llama_n_ctx_train src/llama.cpp:19713
    #1 0x5609ca50b3ae in llama_sampler_init_dry src/llama.cpp:21871
    #2 0x5609ca73b0cb in llama_sampler_dry_clone src/llama-sampling.cpp:1880
    #3 0x5609ca72dc69 in llama_sampler_clone src/llama-sampling.cpp:233
    #4 0x5609ca72ef9b in llama_sampler_chain_clone src/llama-sampling.cpp:333
    #5 0x5609ca72dc69 in llama_sampler_clone src/llama-sampling.cpp:233
    #6 0x5609ca9aa4a3 in common_sampler_clone(common_sampler*) common/sampling.cpp:259
    #7 0x5609caa36d70 in main examples/speculative/speculative.cpp:454
    #8 0x7fc44782814f in __libc_start_call_main ../sysdeps/nptl/libc_start_call_main.h:58
    #9 0x7fc447828208 in __libc_start_main_impl ../csu/libc-start.c:360
    #10 0x5609c97a9084 in _start (/home/diego/code/llama.cpp/llama-speculative+0x18e084) (BuildId: 8bdd8edabb9d68458f2c0ed5c8c4ad2e5801a83a)

AddressSanitizer can not provide additional info.
SUMMARY: AddressSanitizer: SEGV src/llama.cpp:19713 in llama_n_ctx_train
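
As a reference, here is a minimal, self-contained sketch of the failure mode implied by the trace above (the names and structure are simplified and hypothetical, not the actual llama.cpp source): cloning the DRY sampler re-runs its model-dependent initialization, which reads the training context length from a model pointer that is NULL for this sampler.

// Hypothetical, simplified reconstruction of the crash path; not the real llama.cpp code.
#include <cstdint>
#include <cstdio>

struct llama_hparams { uint32_t n_ctx_train; };
struct llama_model   { llama_hparams hparams; };

// Reads a field of *model without a NULL check -- matches the null-page READ
// that ASan reports in llama_n_ctx_train when the DRY sampler is cloned.
static int32_t n_ctx_train_unchecked(const llama_model * model) {
    return (int32_t) model->hparams.n_ctx_train;             // SEGV when model == nullptr
}

// One possible guard: fall back to a value cached when the sampler was first built.
static int32_t n_ctx_train_guarded(const llama_model * model, int32_t cached) {
    return model ? (int32_t) model->hparams.n_ctx_train : cached;
}

int main() {
    const llama_model * model = nullptr;                     // sampler cloned without a model
    printf("%d\n", n_ctx_train_guarded(model, 131072));      // prints the cached value
    // n_ctx_train_unchecked(model);                         // would segfault, as in the issue
}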

@slaren slaren added bug Something isn't working low severity Used to report low severity bugs in llama.cpp (e.g. cosmetic issues, non critical UI glitches) and removed bug-unconfirmed critical severity Used to report critical severity bugs in llama.cpp (e.g. Crashing, Corrupted, Dataloss) labels Nov 4, 2024
@slaren
Collaborator

slaren commented Nov 4, 2024

Running this under AddressSanitizer also detects a buffer overflow after generating tokens for a while (unrelated to the DRY issue):

==2437758==ERROR: AddressSanitizer: heap-buffer-overflow on address 0x60300241c108 at pc 0x5621951d6375 bp 0x7ffc15d82830 sp 0x7ffc15d82820
READ of size 4 at 0x60300241c108 thread T0
    #0 0x5621951d6374 in main examples/speculative/speculative.cpp:271
    #1 0x7fefa962814f in __libc_start_call_main ../sysdeps/nptl/libc_start_call_main.h:58
    #2 0x7fefa9628208 in __libc_start_main_impl ../csu/libc-start.c:360
    #3 0x562193f4b084 in _start (/home/diego/code/llama.cpp/llama-speculative+0x18e084) (BuildId: 8bdd8edabb9d68458f2c0ed5c8c4ad2e5801a83a)

0x60300241c108 is located 0 bytes after 24-byte region [0x60300241c0f0,0x60300241c108)
allocated by thread T0 here:
    #0 0x7fefd02dfba8 in operator new(unsigned long) ../../../../src/libsanitizer/asan/asan_new_delete.cpp:95
    #1 0x562194f07cfa in std::__new_allocator<llama_token_data>::allocate(unsigned long, void const*) /usr/include/c++/13/bits/new_allocator.h:147
    #2 0x562194ef7fa2 in std::allocator_traits<std::allocator<llama_token_data> >::allocate(std::allocator<llama_token_data>&, unsigned long) /usr/include/c++/13/bits/alloc_traits.h:482
    #3 0x562194ef7fa2 in std::_Vector_base<llama_token_data, std::allocator<llama_token_data> >::_M_allocate(unsigned long) /usr/include/c++/13/bits/stl_vector.h:378
    #4 0x5621951ecfeb in void std::vector<llama_token_data, std::allocator<llama_token_data> >::_M_range_initialize<llama_token_data*>(llama_token_data*, llama_token_data*, std::forward_iterator_tag) /usr/include/c++/13/bits/stl_vector.h:1689
    #5 0x5621951ea7b6 in std::vector<llama_token_data, std::allocator<llama_token_data> >::vector<llama_token_data*, void>(llama_token_data*, llama_token_data*, std::allocator<llama_token_data> const&) /usr/include/c++/13/bits/stl_vector.h:708
    #6 0x5621951daa7d in main examples/speculative/speculative.cpp:546
    #7 0x7fefa962814f in __libc_start_call_main ../sysdeps/nptl/libc_start_call_main.h:58

SUMMARY: AddressSanitizer: heap-buffer-overflow examples/speculative/speculative.cpp:271 in main

cc @ggerganov
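
To illustrate the class of error being reported (a hedged sketch, not the actual speculative.cpp logic): llama_token_data is a 12-byte struct, so the 24-byte region is a std::vector holding two candidates, and a 4-byte READ located 0 bytes past it corresponds to indexing the vector at its own size.

// Hedged illustration of the reported overflow pattern; not the real speculative.cpp code.
#include <cstdint>
#include <cstdio>
#include <vector>

struct llama_token_data { int32_t id; float logit; float p; };   // 12 bytes; a 24-byte region holds 2

int main() {
    std::vector<llama_token_data> cur = { {10, 0.f, 0.f}, {20, 0.f, 0.f} };   // size() == 2
    size_t i = cur.size();                            // an index that ran one past the last candidate
    // int32_t id = cur[i].id;                        // READ of size 4 just past the region (ASan)
    int32_t id = i < cur.size() ? cur[i].id : -1;     // a bounds check avoids the out-of-range read
    printf("%d\n", id);
}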

@AbdullahMPrograms
Author

AbdullahMPrograms commented Nov 4, 2024

Interestingly, I seem to have run into a different issue with the --sampling-seq modifier when using speculative decoding with Qwen 2.5; Llama 3.1 seems to be working just fine:
speculativelog.txt

@slaren
Collaborator

slaren commented Nov 4, 2024

Looks like an issue with -sm row, but I cannot reproduce it.

@AbdullahMPrograms
Author

It seems it only occurs when using Qwen2.5-0.5B as the draft model; 1.5B and onwards operate as expected.

@slaren
Collaborator

slaren commented Nov 5, 2024

I can reproduce it now. I think this is because the model is so small that the tensor does not have enough rows, and some devices end up with 0 rows, which causes the event to not be created. It can be reproduced with llama-cli; it's not specific to the speculative example. cc @JohannesGaessler
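
To make the arithmetic concrete, here is a hedged sketch of how row-mode splitting can leave a device with 0 rows when a tensor is very small; the padding value and rounding below are simplified assumptions, not the exact ggml-cuda logic.

// Hedged sketch of -sm row splitting; padding and rounding are illustrative assumptions.
#include <algorithm>
#include <cstdio>

int main() {
    const int   n_rows   = 64;                      // a tiny weight matrix in a 0.5B draft model
    const int   padding  = 32;                      // assume row boundaries are padded to a multiple
    const float split[3] = {0.34f, 0.67f, 1.00f};   // cumulative per-device split fractions

    int prev = 0;
    for (int dev = 0; dev < 3; ++dev) {
        int end = (int)(split[dev] * n_rows);
        end = std::min(((end + padding - 1) / padding) * padding, n_rows);   // pad boundary up
        printf("device %d: %d rows\n", dev, end - prev);                     // device 2 gets 0 rows
        prev = end;
    }
    // A device that owns 0 rows never creates its event, so code that later uses that event
    // can fail -- consistent with the crash seen with Qwen2.5-0.5B and -sm row.
}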

@AbdullahMPrograms
Author

I see, that does make sense

@wwoodsTM
Contributor

wwoodsTM commented Nov 5, 2024

Looks like a crash in the DRY sampler when it is cloned, due to the model being NULL. Removing the DRY sampler from the sampler sequence with --sampling-seq should allow you to use the speculative example until this is fixed. cc @wwoodsTM

Thank you for the heads-up; I will try to get this fixed ASAP.

@wwoodsTM
Contributor

wwoodsTM commented Nov 6, 2024

@slaren PR #10192 should fix this cloning issue. Your pinpointing the issue so precisely was very helpful; thanks.

@ggerganov
Owner

Running this under AddressSanitizer also detects a buffer overflow after generating tokens for a while (unrelated to the DRY issue):

@slaren Do you have a repro? I'm running a few tests here with --sampling-seq k and cannot reproduce the sanitizer issue.

@slaren
Collaborator

slaren commented Nov 6, 2024

I can reproduce it reliably with this command line:

./llama-speculative \
    -m models/Meta-Llama-3.1-8B-Instruct/ggml-model-Q4_0.gguf \
    -md models/Llama-3.2-1B-Instruct-IQ3_M.gguf \
    -p "// Quick-sort implementation in C (4 spaces indentation + detailed comments) and sample usage" \
    -c 8000 -ngl 99 -ngld 30 --draft 16 --sampling-seq k,m,t
