Support FP8 and INT8 KV cache according to ROCm/aiter#90
mawong-amd committed Feb 7, 2025
1 parent 3caff4d commit c127e9a
Showing 1 changed file with 8 additions and 1 deletion.
9 changes: 8 additions & 1 deletion vllm/attention/ops/paged_attn_ater.py
@@ -150,8 +150,15 @@ def forward_decode(
         # For context len > 8192, use V2 kernel to avoid shared memory shortage.

         max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
-        if kv_cache_dtype not in ['int8', 'fp8', 'fp8', 'fp8_e5m2', 'fp8_e4m3']:
+        if kv_cache_dtype not in ["int8", "fp8", "fp8_e4m3"]:
             k_scale, v_scale = (None, None)
+            query = query.contiguous()
+        elif "fp8" in kv_cache_dtype:
+            key_cache = key_cache.view(torch.float8_e4m3fnuz)
+            value_cache = value_cache.view(torch.float8_e4m3fnuz)
+        else:
+            key_cache = key_cache.view(torch.int8)
+            value_cache = value_cache.view(torch.int8)
         dtype=out.dtype
         aiter.pa_fwd_asm(query.to(torch.bfloat16), key_cache, value_cache, block_tables, seq_lens, max_num_blocks_per_seq, k_scale, v_scale,out)
         if dtype==torch.float16:
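For reference, the sketch below (not part of the commit) isolates the dtype dispatch the added lines perform before the aiter ASM paged-attention kernel is called: the KV-cache tensors are stored as raw bytes and are reinterpreted in place as the quantized element type, while an unquantized cache keeps its dtype and simply drops the scales. The helper name reinterpret_kv_cache is hypothetical; the commit does this inline inside forward_decode.

import torch

def reinterpret_kv_cache(key_cache: torch.Tensor,
                         value_cache: torch.Tensor,
                         kv_cache_dtype: str):
    # Hypothetical helper; the commit performs this logic inline in
    # forward_decode before calling aiter.pa_fwd_asm.
    if kv_cache_dtype in ("fp8", "fp8_e4m3"):
        # ROCm uses the e4m3fnuz FP8 encoding, so the byte-for-byte
        # reinterpretation targets torch.float8_e4m3fnuz, not float8_e4m3fn.
        return (key_cache.view(torch.float8_e4m3fnuz),
                value_cache.view(torch.float8_e4m3fnuz))
    if kv_cache_dtype == "int8":
        # INT8 cache: same trick, viewing the storage as signed 8-bit ints.
        return key_cache.view(torch.int8), value_cache.view(torch.int8)
    # Unquantized cache: leave the tensors as-is; k_scale/v_scale go unused.
    return key_cache, value_cache

Tensor.view(dtype) only reinterprets the existing storage (all of these element types are one byte wide), so no copy or dequantization happens here; the kernel consumes the quantized values together with k_scale and v_scale.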
