
vulkan: Use fp16 for the flash attention P*V multiplication #12783

Merged
1 commit merged into ggml-org:master on Apr 9, 2025

Conversation

jeffbolznv
Collaborator

This is consistent with the ggml-cuda behavior and the mul_mat fallback.

cuda: https://github.com/ggml-org/llama.cpp/blob/master/ggml/src/ggml-cuda/fattn-mma-f16.cuh#L13
fallback: https://github.com/ggml-org/llama.cpp/blob/master/src/llama-graph.cpp#L1170
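Below is a minimal standalone C++ sketch (not the Vulkan shader itself, and assuming a compiler with `_Float16` support such as recent GCC/Clang) of why the P*V product tolerates fp16 operands: after softmax, every row of P lies in [0, 1] and sums to 1, so P*V is a convex combination of V rows and stays within V's dynamic range, unlike the Q*K^T logits, which can need fp32. The kv/hsv sizes are borrowed from the benchmark cases below, and the accumulator is kept in fp32 to isolate the effect of operand precision.

```cpp
// Sketch: compare fp32 P*V against the same product with fp16 (P, V) operands.
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <random>
#include <vector>

int main() {
    const int kv = 4096, hsv = 128;  // KV length and V head size (illustrative)
    std::mt19937 rng(42);
    std::normal_distribution<float> dist(0.0f, 1.0f);

    // Random logits -> one softmax row P (values in [0, 1], summing to 1).
    std::vector<float> p(kv), v(kv * hsv);
    float maxl = -INFINITY, sum = 0.0f;
    for (auto & x : p) { x = dist(rng); maxl = std::max(maxl, x); }
    for (auto & x : p) { x = std::exp(x - maxl); sum += x; }
    for (auto & x : p) { x /= sum; }
    for (auto & x : v) { x = dist(rng); }

    // Reference: fp32 operands. Test: fp16 operands, fp32 accumulator.
    double max_rel_err = 0.0;
    for (int j = 0; j < hsv; ++j) {
        float ref = 0.0f, f16 = 0.0f;
        for (int i = 0; i < kv; ++i) {
            ref += p[i] * v[i * hsv + j];
            f16 += (float)(_Float16)p[i] * (float)(_Float16)v[i * hsv + j];
        }
        double err = std::fabs(f16 - ref) / (std::fabs(ref) + 1e-6);
        max_rel_err = std::max(max_rel_err, err);
    }
    printf("max relative error with fp16 P*V operands: %.3e\n", max_rel_err);
    return 0;
}
```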

before:

Z:\github\jeffbolznv\llama.cpp\build\bin\RelWithDebInfo>llama-bench.exe -m C:\models\meta-llama-3-8b-instruct.Q4_K_M.gguf -p 4096,8192,16384 -n 0 -fa 1 --repetitions 1
ggml_vulkan: Found 1 Vulkan devices:
ggml_vulkan: 0 = NVIDIA GeForce RTX 4070 (NVIDIA) | uma: 0 | fp16: 1 | warp size: 32 | shared memory: 49152 | int dot: 1 | matrix cores: NV_coopmat2
| model                          |       size |     params | backend    | ngl | fa |          test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | -: | ------------: | -------------------: |
| llama 8B Q4_K - Medium         |   4.58 GiB |     8.03 B | Vulkan     |  99 |  1 |        pp4096 |       3096.20 ± 0.00 |
| llama 8B Q4_K - Medium         |   4.58 GiB |     8.03 B | Vulkan     |  99 |  1 |        pp8192 |       2809.96 ± 0.00 |
| llama 8B Q4_K - Medium         |   4.58 GiB |     8.03 B | Vulkan     |  99 |  1 |       pp16384 |       2381.60 ± 0.00 |
  FLASH_ATTN_EXT(hsk=64,hsv=64,nh=8,nr=4,kv=4096,nb=1,mask=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]):                 47696 runs -    21.81 us/run -  33.55 MFLOP/run -   1.54 TFLOPS
  FLASH_ATTN_EXT(hsk=128,hsv=128,nh=8,nr=4,kv=4096,nb=1,mask=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]):               32802 runs -    30.57 us/run -  67.11 MFLOP/run -   2.20 TFLOPS
  FLASH_ATTN_EXT(hsk=64,hsv=64,nh=8,nr=4,kv=8192,nb=1,mask=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]):                 31311 runs -    33.14 us/run -  67.11 MFLOP/run -   2.02 TFLOPS
  FLASH_ATTN_EXT(hsk=128,hsv=128,nh=8,nr=4,kv=8192,nb=1,mask=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]):               20888 runs -    47.96 us/run - 134.22 MFLOP/run -   2.80 TFLOPS
  FLASH_ATTN_EXT(hsk=64,hsv=64,nh=8,nr=4,kv=16384,nb=1,mask=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]):                18650 runs -    55.29 us/run - 134.22 MFLOP/run -   2.43 TFLOPS
  FLASH_ATTN_EXT(hsk=128,hsv=128,nh=8,nr=4,kv=16384,nb=1,mask=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]):               6714 runs -   152.30 us/run - 268.44 MFLOP/run -   1.76 TFLOPS

after:

Z:\github\jeffbolznv\llama.cpp\build\bin\RelWithDebInfo>llama-bench.exe -m C:\models\meta-llama-3-8b-instruct.Q4_K_M.gguf -p 4096,8192,16384 -n 0 -fa 1 --repetitions 1
ggml_vulkan: Found 1 Vulkan devices:
ggml_vulkan: 0 = NVIDIA GeForce RTX 4070 (NVIDIA) | uma: 0 | fp16: 1 | warp size: 32 | shared memory: 49152 | int dot: 1 | matrix cores: NV_coopmat2
| model                          |       size |     params | backend    | ngl | fa |          test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | -: | ------------: | -------------------: |
| llama 8B Q4_K - Medium         |   4.58 GiB |     8.03 B | Vulkan     |  99 |  1 |        pp4096 |       3119.81 ± 0.00 |
| llama 8B Q4_K - Medium         |   4.58 GiB |     8.03 B | Vulkan     |  99 |  1 |        pp8192 |       2877.54 ± 0.00 |
| llama 8B Q4_K - Medium         |   4.58 GiB |     8.03 B | Vulkan     |  99 |  1 |       pp16384 |       2487.28 ± 0.00 |


  FLASH_ATTN_EXT(hsk=64,hsv=64,nh=8,nr=4,kv=4096,nb=1,mask=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]):                 47696 runs -    22.11 us/run -  33.55 MFLOP/run -   1.52 TFLOPS
  FLASH_ATTN_EXT(hsk=128,hsv=128,nh=8,nr=4,kv=4096,nb=1,mask=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]):               37275 runs -    27.34 us/run -  67.11 MFLOP/run -   2.45 TFLOPS
  FLASH_ATTN_EXT(hsk=64,hsv=64,nh=8,nr=4,kv=8192,nb=1,mask=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]):                 35784 runs -    29.07 us/run -  67.11 MFLOP/run -   2.31 TFLOPS
  FLASH_ATTN_EXT(hsk=128,hsv=128,nh=8,nr=4,kv=8192,nb=1,mask=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]):               24618 runs -    41.80 us/run - 134.22 MFLOP/run -   3.21 TFLOPS
  FLASH_ATTN_EXT(hsk=64,hsv=64,nh=8,nr=4,kv=16384,nb=1,mask=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]):                21634 runs -    47.58 us/run - 134.22 MFLOP/run -   2.82 TFLOPS
  FLASH_ATTN_EXT(hsk=128,hsv=128,nh=8,nr=4,kv=16384,nb=1,mask=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]):               6714 runs -   151.99 us/run - 268.44 MFLOP/run -   1.77 TFLOPS

github-actions bot added the Vulkan (Issues specific to the Vulkan backend) and ggml (changes relating to the ggml tensor library for machine learning) labels on Apr 6, 2025
jeffbolznv requested a review from 0cc4m on April 6, 2025 at 18:10
@0cc4m
Collaborator

0cc4m commented Apr 9, 2025

I don't see much of a difference on 3090, but no regression either:

Master:

| model                          |       size |     params | backend    | ngl | fa |          test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | -: | ------------: | -------------------: |
| llama 8B Q4_K - Small          |   4.36 GiB |     8.03 B | Vulkan     |  99 |  1 |        pp4096 |       3397.53 ± 0.00 |
| llama 8B Q4_K - Small          |   4.36 GiB |     8.03 B | Vulkan     |  99 |  1 |        pp8192 |       3084.63 ± 0.00 |
| llama 8B Q4_K - Small          |   4.36 GiB |     8.03 B | Vulkan     |  99 |  1 |       pp16384 |       2503.77 ± 0.00 |
  FLASH_ATTN_EXT(hsk=64,hsv=64,nh=8,nr=4,kv=4096,nb=1,mask=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]):                 23848 runs -    45.95 us/run -  33.55 MFLOP/run - 730.26 GFLOPS
  FLASH_ATTN_EXT(hsk=128,hsv=128,nh=8,nr=4,kv=4096,nb=1,mask=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]):               17892 runs -    57.68 us/run -  67.11 MFLOP/run -   1.16 TFLOPS
  FLASH_ATTN_EXT(hsk=64,hsv=64,nh=8,nr=4,kv=8192,nb=1,mask=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]):                 17892 runs -    57.70 us/run -  67.11 MFLOP/run -   1.16 TFLOPS
  FLASH_ATTN_EXT(hsk=128,hsv=128,nh=8,nr=4,kv=8192,nb=1,mask=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]):               12682 runs -    80.59 us/run - 134.22 MFLOP/run -   1.67 TFLOPS
  FLASH_ATTN_EXT(hsk=64,hsv=64,nh=8,nr=4,kv=16384,nb=1,mask=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]):                12682 runs -    79.57 us/run - 134.22 MFLOP/run -   1.69 TFLOPS
  FLASH_ATTN_EXT(hsk=128,hsv=128,nh=8,nr=4,kv=16384,nb=1,mask=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]):               8579 runs -   116.59 us/run - 268.44 MFLOP/run -   2.30 TFLOPS

PR:

| model                          |       size |     params | backend    | ngl | fa |          test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | -: | ------------: | -------------------: |
| llama 8B Q4_K - Small          |   4.36 GiB |     8.03 B | Vulkan     |  99 |  1 |        pp4096 |       3381.02 ± 0.00 |
| llama 8B Q4_K - Small          |   4.36 GiB |     8.03 B | Vulkan     |  99 |  1 |        pp8192 |       3040.78 ± 0.00 |
| llama 8B Q4_K - Small          |   4.36 GiB |     8.03 B | Vulkan     |  99 |  1 |       pp16384 |       2599.01 ± 0.00 |
  FLASH_ATTN_EXT(hsk=64,hsv=64,nh=8,nr=4,kv=4096,nb=1,mask=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]):                 23848 runs -    45.65 us/run -  33.55 MFLOP/run - 734.96 GFLOPS
  FLASH_ATTN_EXT(hsk=128,hsv=128,nh=8,nr=4,kv=4096,nb=1,mask=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]):               17892 runs -    58.06 us/run -  67.11 MFLOP/run -   1.16 TFLOPS
  FLASH_ATTN_EXT(hsk=64,hsv=64,nh=8,nr=4,kv=8192,nb=1,mask=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]):                 17892 runs -    56.92 us/run -  67.11 MFLOP/run -   1.18 TFLOPS
  FLASH_ATTN_EXT(hsk=128,hsv=128,nh=8,nr=4,kv=8192,nb=1,mask=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]):               12682 runs -    80.56 us/run - 134.22 MFLOP/run -   1.67 TFLOPS
  FLASH_ATTN_EXT(hsk=64,hsv=64,nh=8,nr=4,kv=16384,nb=1,mask=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]):                13428 runs -    78.24 us/run - 134.22 MFLOP/run -   1.72 TFLOPS
  FLASH_ATTN_EXT(hsk=128,hsv=128,nh=8,nr=4,kv=16384,nb=1,mask=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]):               8579 runs -   119.57 us/run - 268.44 MFLOP/run -   2.25 TFLOPS

0cc4m merged commit 7ecd780 into ggml-org:master on Apr 9, 2025
47 checks passed