diff --git a/.gitignore b/.gitignore
index 90506dc7..814af71b 100644
--- a/.gitignore
+++ b/.gitignore
@@ -40,4 +40,7 @@ __pycache__
 *.egg-info

 # editor
-.vscode
\ No newline at end of file
+.vscode
+
+# debug folder
+debug
\ No newline at end of file
diff --git a/aiter/ops/attention.py b/aiter/ops/attention.py
index d5453596..f0195f32 100644
--- a/aiter/ops/attention.py
+++ b/aiter/ops/attention.py
@@ -54,7 +54,6 @@ def pa_fwd_asm(
 def paged_attention_rocm(
     out: torch.Tensor,
     exp_sums: torch.Tensor,
-    block_mapping: torch.Tensor,
     max_logits: torch.Tensor,
     tmp_out: torch.Tensor,
     query: torch.Tensor,
diff --git a/aiter/ops/quant.py b/aiter/ops/quant.py
index 30383f5a..65cb8bbb 100644
--- a/aiter/ops/quant.py
+++ b/aiter/ops/quant.py
@@ -19,6 +19,14 @@ def moe_smoothquant_fwd(


 # following are pure torch implement
+def get_dtype_max(dtype):
+    try:
+        dtypeMax = torch.finfo(dtype).max
+    except:
+        dtypeMax = torch.iinfo(dtype).max
+    return dtypeMax
+
+
 def pertoken_quant(x, y_scale_dtype=torch.float, x_scale=None, quant_dtype=torch.int8):
     if x_scale is None:
         hidden_states = x
@@ -32,10 +40,7 @@ def pertoken_quant(x, y_scale_dtype=torch.float, x_scale=None, quant_dtype=torch
         keepdim=True
     )

-    try:
-        dtypeMax = torch.finfo(quant_dtype).max
-    except:
-        dtypeMax = torch.iinfo(quant_dtype).max
+    dtypeMax = get_dtype_max(quant_dtype)

     per_token_scale = per_token_amax.to(dtype=torch.float32) / dtypeMax
     per_token_scale[per_token_scale == 0] = 1
@@ -46,10 +51,19 @@ def pertoken_quant(x, y_scale_dtype=torch.float, x_scale=None, quant_dtype=torch
     return y, y_scale


+def per_tensor_quant(x, scale=None, scale_dtype=torch.float, quant_dtype=torch.int8):
+    if scale is None:
+        dtypeMax = get_dtype_max(quant_dtype)
+        scale = torch.abs(x.to(torch.float)).max() / dtypeMax
+    y = x/scale
+
+    return y.to(quant_dtype), scale.to(scale_dtype)
+
+
 @compile_ops("module_quant")
 def static_scaled_fp8_quant(
     out: Tensor, input: Tensor, scale: Tensor
-):...
+): ...


 @compile_ops("module_quant")
@@ -65,4 +79,4 @@ def dynamic_scaled_fp8_quant(
 @compile_ops("module_quant")
 def dynamic_per_token_scaled_fp8_quant(
     out: Tensor, input: Tensor, scales: Tensor, scale_ub: Optional[Tensor] = None
-):...
+): ...
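
The two pure-torch helpers added to aiter/ops/quant.py above can be sanity-checked in isolation. A minimal round-trip sketch, assuming the aiter.ops.quant module is importable from this repo; CPU tensors are enough since nothing here touches a kernel:

    # Hedged usage sketch for per_tensor_quant / pertoken_quant.
    import torch
    from aiter.ops.quant import per_tensor_quant, pertoken_quant

    x = torch.randn(32, 128)

    # per-tensor: a single scale = amax(|x|) / int8_max shared by the whole tensor
    x_q, x_scale = per_tensor_quant(x, quant_dtype=torch.int8)
    x_dq = x_q.to(torch.float32) * x_scale

    # per-token: one scale per row (last dim reduced), y_scale has shape [32, 1]
    y_q, y_scale = pertoken_quant(x, torch.float, quant_dtype=torch.int8)
    y_dq = y_q.to(torch.float32) * y_scale

    print((x - x_dq).abs().max(), (x - y_dq).abs().max())
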
diff --git a/aiter/test_common.py b/aiter/test_common.py
index f2e74412..bf720a3b 100644
--- a/aiter/test_common.py
+++ b/aiter/test_common.py
@@ -106,27 +106,29 @@ def get_trace_perf(prof, num_iters):
     return df.at[avg_name, 'device_time_total']


-def checkAllclose(a, b, rtol=1e-2, atol=1e-2, msg=''):
+def checkAllclose(a, b, rtol=1e-2, atol=1e-2, msg='', printNum=8):
     isClose = torch.isclose(a, b, rtol=rtol, atol=atol)
     mask = ~isClose
     if isClose.all():
         logger.info(f'{msg}[checkAllclose {atol=} {rtol=} passed~]')
     else:
-        percent = (a[mask]).numel()/a.numel()
+        num = mask.sum()
+        printNum = min(printNum, num)
+        percent = num/a.numel()
         delta = (a-b)[mask]
         if percent > 0.01:
             logger.info(f'''{msg}[checkAllclose {atol=} {rtol=} failed!]
-        a: {a.shape}
-        {a[mask]}
-        b: {b.shape}
-        {b[mask]}
+        a : {a.shape}
+        {a[mask][:printNum]}
+        b : {b.shape}
+        {b[mask][:printNum]}
         dtlta:
-        {delta}''')
+        {delta[:printNum]}''')
         else:
             logger.info(
                 f'''{msg}[checkAllclose {atol=} {rtol=} waring!]
        a and b results are not all close''')
         logger.info(
-            f'-->max delta:{delta.max()}, delta details: {percent:.1%} ({(a[mask]).numel()} of {a.numel()}) elements')
+            f'-->max delta:{delta.max()}, delta details: {percent:.1%} ({num} of {a.numel()}) elements')


 def tensor_dump(x: torch.tensor, name: str, dir='./'):
diff --git a/csrc/kernels/cache_kernels.cu b/csrc/kernels/cache_kernels.cu
index 10a06882..335778fc 100644
--- a/csrc/kernels/cache_kernels.cu
+++ b/csrc/kernels/cache_kernels.cu
@@ -469,8 +469,20 @@ namespace vllm
         }

         // store the scale
-        k_dequant_scales[head_idx * max_kv_tokens + slot_idx] = k_token_scale;
-        v_dequant_scales[head_idx * max_kv_tokens + slot_idx] = v_token_scale;
+        if constexpr (asmLayout)
+        {
+            // [num_blocks, num_heads, block_size]
+            const int scale_idx = block_size * num_heads * block_idx +
+                                  block_size * head_idx +
+                                  block_offset;
+            k_dequant_scales[scale_idx] = k_token_scale;
+            v_dequant_scales[scale_idx] = v_token_scale;
+        }
+        else
+        {
+            k_dequant_scales[head_idx * max_kv_tokens + slot_idx] = k_token_scale;
+            v_dequant_scales[head_idx * max_kv_tokens + slot_idx] = v_token_scale;
+        }

         // now let's store out
 #pragma unroll
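
With asmLayout enabled, the kernel above writes the per-token dequant scales in a [num_blocks, num_heads, block_size] layout instead of the flat [num_heads, max_kv_tokens] one. A toy index calculation for the two layouts; the slot_idx = block_idx * block_size + block_offset decomposition follows the usual reshape_and_cache convention and is an assumption here, not something this hunk spells out:

    # Toy index math for the two k/v dequant-scale layouts written above.
    num_blocks, num_heads, block_size = 4, 2, 16
    max_kv_tokens = num_blocks * block_size

    slot_idx = 37                                   # illustrative token slot
    block_idx, block_offset = divmod(slot_idx, block_size)
    head_idx = 1

    # asmLayout == true: scales stored as [num_blocks, num_heads, block_size]
    asm_idx = block_size * num_heads * block_idx + block_size * head_idx + block_offset

    # asmLayout == false: scales stored as [num_heads, max_kv_tokens]
    legacy_idx = head_idx * max_kv_tokens + slot_idx

    print(asm_idx, legacy_idx)                      # 85, 101 for these toy numbers
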
diff --git a/csrc/py_itfs_cu/asm_pa.cpp b/csrc/py_itfs_cu/asm_pa.cpp
index b2a5cd11..7ce6a64a 100644
--- a/csrc/py_itfs_cu/asm_pa.cpp
+++ b/csrc/py_itfs_cu/asm_pa.cpp
@@ -28,7 +28,7 @@ struct __attribute__((packed)) KernelArgs
     p3 _p12;
     unsigned int mblk;
     p3 _p13;
-    unsigned int batch;
+    unsigned int kv_nheads;
     p3 _p14;
     unsigned int Qs;
     p3 _p15;
@@ -36,6 +36,8 @@ struct __attribute__((packed)) KernelArgs
     p3 _p16;
     unsigned int KVs;
     p3 _p17;
+    unsigned int GQA;
+    p3 _p18;
 };

 const float f_log2E = log2f(expf(1));
@@ -62,7 +64,7 @@ torch::Tensor pa_fwd(torch::Tensor &Q, // [num_seqs, num_heads, hea
                 __func__, " for now only support block_size == 16");

     int dim = head_size;
-    int stride_Q = gqa_ratio * dim * Q.itemsize();
+    int stride_Q = Q.stride(0) * Q.itemsize();
     int stride_KV_head = block_size * dim * K.itemsize();
     int stride_KV_blk = stride_KV_head * num_kv_heads;
     float k_log2e = f_log2E;
@@ -89,27 +91,47 @@ torch::Tensor pa_fwd(torch::Tensor &Q, // [num_seqs, num_heads, hea
     }
     args.sclg2e = k_scalar;
     args.mblk = max_num_blocks;
-    args.batch = batch;
+    args.kv_nheads = num_kv_heads;
     args.Qs = stride_Q;
     args.Bs = stride_KV_blk;
     args.KVs = stride_KV_head;
+    args.GQA = gqa_ratio;
+    // std::cout << "sclg2e: " << args.sclg2e << " mblk:" << args.mblk << " kv_nheads:" << args.kv_nheads << " Qs:" << args.Qs << " Bs:" << args.Bs << " KVs:" << args.KVs << std::endl;

     const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-    static AiterAsmKernel impl_a16w16("pa_kernel_func", "pa_a16w16.co");
-    static AiterAsmKernel impl_a16w8("pa_kernel_func", "pa_a16w8.co");
-    AiterAsmKernel *impl_ptr = &impl_a16w16;
+    AiterAsmKernel *impl_ptr = nullptr;
     if (K_QScale)
-        impl_ptr = &impl_a16w8;
+    {
+        if (K.dtype() == at::ScalarType::Char)
+        {
+            static AiterAsmKernel impl_a16w8_i8("pa_a16w8_2tg_g8_i8", "pa_a16w8_2tg_g8_i8.co");
+            impl_ptr = &impl_a16w8_i8;
+        }
+        else if (K.dtype() == at::ScalarType::Float8_e4m3fnuz)
+        {
+            static AiterAsmKernel impl_a16w8_f8("pa_a16w8_2tg_g8_f8", "pa_a16w8_2tg_g8_f8.co");
+            impl_ptr = &impl_a16w8_f8;
+        }
+    }
+    else
+    {
+        TORCH_CHECK(Q.is_contiguous(),
+                    __func__, ": a16w16 only supports Q.is_contiguous() for now");
+        static AiterAsmKernel impl_a16w16("pa_kernel_func", "pa_a16w16.co");
+        impl_ptr = &impl_a16w16;
+    }
+    TORCH_CHECK(impl_ptr != nullptr,
+                __func__, ": unsupported input type");

     impl_ptr->launch_kernel({&args,
                              &arg_size,
-                             1,             // gdx
-                             batch,         // gdy
-                             1,             // gdz
-                             256,           // bdx: 4 wv64
-                             1,             // bdy
-                             1,             // bdz
+                             num_kv_heads,  // gdx
+                             batch,         // gdy
+                             1,             // gdz
+                             256,           // bdx: 4 wv64
+                             1,             // bdy
+                             1,             // bdz
                              stream});
     return output;
 }
\ No newline at end of file
diff --git a/hsa/pa_a16w8_2tg_g8_f8.co b/hsa/pa_a16w8_2tg_g8_f8.co
new file mode 100755
index 00000000..fd250081
Binary files /dev/null and b/hsa/pa_a16w8_2tg_g8_f8.co differ
diff --git a/hsa/pa_a16w8_2tg_g8_i8.co b/hsa/pa_a16w8_2tg_g8_i8.co
new file mode 100755
index 00000000..233d3c00
Binary files /dev/null and b/hsa/pa_a16w8_2tg_g8_i8.co differ
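
The kernel-argument changes above swap the batch field for kv_nheads, pass the GQA ratio explicitly, take Qs from Q.stride(0) rather than gqa_ratio * dim (so a query that is a strided view into a fused QKV buffer no longer needs a copy), and launch one workgroup column per KV head. A rough host-side sketch with made-up sizes (not the C++ code itself):

    # Rough picture of the new kernel arguments / launch grid (illustrative only).
    num_query_heads, num_kv_heads, head_size, batch = 32, 8, 128, 4

    gqa_ratio = num_query_heads // num_kv_heads          # -> args.GQA
    itemsize = 2                                         # bf16 query

    # a query sliced from a fused QKV buffer strides over q + k + v heads per row
    q_row_stride = (num_query_heads + 2 * num_kv_heads) * head_size
    Qs = q_row_stride * itemsize                         # args.Qs = Q.stride(0) * Q.itemsize()

    grid = (num_kv_heads, batch, 1)                      # gdx, gdy, gdz
    block = (256, 1, 1)                                  # 4 wave64 per workgroup
    print(gqa_ratio, Qs, grid, block)
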
diff --git a/op_tests/test_kvcache.py b/op_tests/test_kvcache.py
index 31374b87..c8cd5bb5 100644
--- a/op_tests/test_kvcache.py
+++ b/op_tests/test_kvcache.py
@@ -30,10 +30,15 @@ def run_torch(key, value, k_cache, v_cache, slot_mapping, block_size, x, asm_lay
                                              quant_dtype=quantCfg['quant_dtype'])
         k_scale_ = k_scale_.permute(0, 1, 3, 2).view(
             num_batch*num_tokens, num_heads).contiguous()
-
-        k_scale = k_scale.permute(1, 0).contiguous()
-        k_scale[slot_mapping] = k_scale_
-        k_scale = k_scale.permute(1, 0).contiguous()
+        if asm_layout:
+            k_scale = k_scale.permute(0, 2, 1).contiguous().view(-1, num_heads)
+            k_scale[slot_mapping] = k_scale_
+            k_scale = k_scale.view(
+                num_blocks, block_size, num_heads).permute(0, 2, 1).contiguous()
+        else:
+            k_scale = k_scale.permute(1, 0).contiguous()
+            k_scale[slot_mapping] = k_scale_
+            k_scale = k_scale.permute(1, 0).contiguous()

     k_cache = k_cache.permute(0, 3, 1, 2, 4).contiguous().view(
         -1, num_heads, head_size)
@@ -48,10 +53,17 @@ def run_torch(key, value, k_cache, v_cache, slot_mapping, block_size, x, asm_lay
                                              quant_dtype=quantCfg['quant_dtype'])
         v_scale_ = v_scale_.permute(0, 1, 3, 2).view(
             num_batch*num_tokens, num_heads).contiguous()
+        if asm_layout:
+            v_scale = v_scale.permute(
+                0, 2, 1).contiguous().view(-1, num_heads)
+            v_scale[slot_mapping] = v_scale_
+            v_scale = v_scale.view(
+                num_blocks, block_size, num_heads).permute(0, 2, 1).contiguous()
+        else:
+            v_scale = v_scale.permute(1, 0).contiguous()
+            v_scale[slot_mapping] = v_scale_
+            v_scale = v_scale.permute(1, 0).contiguous()

-        v_scale = v_scale.permute(1, 0).contiguous()
-        v_scale[slot_mapping] = v_scale_
-        v_scale = v_scale.permute(1, 0).contiguous()
     if asm_layout:
         v_cache = v_cache.permute(0, 2, 4, 1, 3).contiguous().view(
             -1, num_heads, head_size)
@@ -105,9 +117,11 @@ def test_reshape_and_cache(ctx_lens: int,
     if asm_layout:
         k_cache_shape = (bs*num_blocks, kvhead, head_size // x, block_size, x)
         v_cache_shape = (bs*num_blocks, kvhead, block_size//x, head_size, x)
+        kv_scale_shape = (bs*num_blocks, kvhead, block_size)
     else:
         k_cache_shape = (bs*num_blocks, kvhead, head_size // x, block_size, x)
         v_cache_shape = (bs*num_blocks, kvhead, head_size, block_size)
+        kv_scale_shape = (kvhead, bs*max_token_num_support)

     # ##################################################### prefill part
     qkv = torch.randn(
@@ -117,7 +131,7 @@ def test_reshape_and_cache(ctx_lens: int,
     k_cache = torch.empty(k_cache_shape, dtype=DTyoe_KVCache, device=device)
     v_cache = torch.empty(v_cache_shape, dtype=DTyoe_KVCache, device=device)
     if quantCfg:
-        k_scale = torch.empty(kvhead, bs*max_token_num_support,
+        k_scale = torch.empty(kv_scale_shape,
                               dtype=quantCfg['y_scale_dtype'], device=key.device)
         v_scale = torch.empty_like(k_scale)
         quantCfg['k_scale'] = k_scale.clone()
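
For the asm layout the reference above keeps the scales as [num_blocks, num_heads, block_size] and scatters the newly cached tokens through a [tokens, heads] view; a toy version of that round trip (shapes invented for illustration):

    # Toy version of the asm-layout scale scatter used in run_torch above.
    import torch

    num_blocks, num_heads, block_size = 4, 2, 16
    k_scale = torch.zeros(num_blocks, num_heads, block_size)          # asm layout
    new_scales = torch.rand(8, num_heads)                             # 8 newly written tokens
    slot_mapping = torch.randperm(num_blocks * block_size)[:8]

    flat = k_scale.permute(0, 2, 1).contiguous().view(-1, num_heads)  # [tokens, heads]
    flat[slot_mapping] = new_scales
    k_scale = flat.view(num_blocks, block_size,
                        num_heads).permute(0, 2, 1).contiguous()      # back to asm layout
    print(k_scale.shape)                                              # torch.Size([4, 2, 16])
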
diff --git a/op_tests/test_pa.py b/op_tests/test_pa.py
index 954ecef4..aef296b0 100644
--- a/op_tests/test_pa.py
+++ b/op_tests/test_pa.py
@@ -21,12 +21,13 @@
 }

 ck_naive_quant_algo = [
     'NO',
-    'KV_8BIT_PERHEAD',
+    'KV_8BIT_PER_HEAD',
     # // FP8/INT8 quant for KVCache, per-token quant
     # // [num_tokens, nhead, hdim] -> [nhead, num_tokens]
-    'KV_8BIT_PERTOKEN',
+    'KV_8BIT_PER_TOKEN',
     # // same as 8bit per token quant but 4 bit
-    'KV_4BIT_PERTOKEN',
+    'KV_4BIT_PER_TOKEN',
+    'KV_8BIT_PER_TENSOR',
 ]
@@ -75,32 +76,32 @@ def kv_cache_factory(
     scale = head_size**-0.5
     x = 16 // torch_dtype.itemsize
-    key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x)
-    key_caches: List[torch.Tensor] = []
+    k_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x)
+    k_caches: List[torch.Tensor] = []
     for _ in range(num_layers):
-        key_cache = torch.empty(size=key_cache_shape,
-                                dtype=torch_dtype,
-                                device=device)
+        k_cache = torch.empty(size=k_cache_shape,
+                              dtype=torch_dtype,
+                              device=device)
         if cache_dtype in ["auto", "half", "bfloat16", "float"]:
-            key_cache.uniform_(*uniform_range)
+            k_cache.uniform_(*uniform_range)
         else:
             raise ValueError(
                 f"Does not support key cache of type {cache_dtype}")
-        key_caches.append(key_cache)
+        k_caches.append(k_cache)

-    value_cache_shape = (num_blocks, num_heads, head_size, block_size)
-    value_caches: List[torch.Tensor] = []
+    v_cache_shape = (num_blocks, num_heads, head_size, block_size)
+    v_caches: List[torch.Tensor] = []
     for _ in range(num_layers):
-        value_cache = torch.empty(size=value_cache_shape,
-                                  dtype=torch_dtype,
-                                  device=device)
+        v_cache = torch.empty(size=v_cache_shape,
+                              dtype=torch_dtype,
+                              device=device)
         if cache_dtype in ["auto", "half", "bfloat16", "float"]:
-            value_cache.uniform_(*uniform_range)
+            v_cache.uniform_(*uniform_range)
         else:
             raise ValueError(
                 f"Does not support value cache of type {cache_dtype}")
-        value_caches.append(value_cache)
-    return key_caches, value_caches
+        v_caches.append(v_cache)
+    return k_caches, v_caches


 FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
@@ -150,58 +151,58 @@ def ref_masked_attention(

 def pertoken_quant_kvcache_symm(
     # [num_blocks, num_heads, head_size // x, block_size, x]
-    key_cache: torch.Tensor,
+    k_cache: torch.Tensor,
     # [num_blocks, num_heads, head_size, block_size]
-    value_cache: torch.Tensor,
+    v_cache: torch.Tensor,
     quant_dtype: torch.dtype,  # e.g. torch.float8_e4m3fnuz
     scale_dtype: torch.dtype = torch.float32
 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
-    num_blocks = key_cache.shape[0]
-    num_heads = key_cache.shape[1]
-    head_dim = value_cache.shape[2]
-    block_size = value_cache.shape[3]
-    # x = key_cache.shape[4]
+    num_blocks = k_cache.shape[0]
+    num_heads = k_cache.shape[1]
+    head_dim = v_cache.shape[2]
+    block_size = v_cache.shape[3]
+    # x = k_cache.shape[4]
     total_tokens = num_blocks * block_size

-    # print(f"{key_cache.shape=}{key_cache.stride()=}")
-    # print(f"{value_cache.shape=}{value_cache.stride()=}")
+    # print(f"{k_cache.shape=}{k_cache.stride()=}")
+    # print(f"{v_cache.shape=}{v_cache.stride()=}")

-    key_cache_permute = key_cache.permute(0, 1, 3, 2, 4).reshape(
+    k_cache_permute = k_cache.permute(0, 1, 3, 2, 4).reshape(
         num_blocks, num_heads, block_size, -1).contiguous()
-    value_cache_permute = value_cache.permute(0, 1, 3, 2).reshape(
+    v_cache_permute = v_cache.permute(0, 1, 3, 2).reshape(
         num_blocks, num_heads, block_size, -1).contiguous()

-    k_quant, k_scale = pertoken_quant(
-        key_cache_permute, scale_dtype, quant_dtype=quant_dtype)
-    v_quant, v_scale = pertoken_quant(
-        value_cache_permute, scale_dtype, quant_dtype=quant_dtype)
+    k_quant, k_scale_asm = pertoken_quant(
+        k_cache_permute, scale_dtype, quant_dtype=quant_dtype)
+    v_quant, v_scale_asm = pertoken_quant(
+        v_cache_permute, scale_dtype, quant_dtype=quant_dtype)

     # NOTE: quant_x and original x could be different
     quant_x = 16 // quant_dtype.itemsize

     k_quant = k_quant.view(num_blocks, num_heads, block_size, head_dim //
                            quant_x, quant_x).permute(0, 1, 3, 2, 4).contiguous()
-    k_scale = k_scale.permute(1, 0, 2, 3).view(
-        num_heads, total_tokens).contiguous()
+    k_scale = k_scale_asm.permute(1, 0, 2, 3).contiguous().view(
+        num_heads, total_tokens)
     v_quant = v_quant.view(num_blocks, num_heads, block_size,
                            head_dim).permute(0, 1, 3, 2).contiguous()
-    v_scale = v_scale.permute(1, 0, 2, 3).view(
-        num_heads, total_tokens).contiguous()
+    v_scale = v_scale_asm.permute(1, 0, 2, 3).contiguous().view(
+        num_heads, total_tokens)

     # print(f"{k_quant.shape=}{k_quant.stride()=}")
     # print(f"{k_scale.shape=}{k_scale.stride()=}")
     # print(f"{v_quant.shape=}{v_quant.stride()=}")
     # print(f"{v_scale.shape=}{v_scale.stride()=}")
-    # print(f"key_cache_permute:{key_cache_permute[0, :, :, :]}, k_quant:{k_quant[0, :, :, :, :]}, k_scale:{k_scale[:, 0]}")
+    # print(f"k_cache_permute:{k_cache_permute[0, :, :, :]}, k_quant:{k_quant[0, :, :, :, :]}, k_scale:{k_scale[:, 0]}")

-    return k_quant, k_scale, v_quant, v_scale
+    return k_quant, k_scale, v_quant, v_scale, k_scale_asm, v_scale_asm


 # @perftest()
 def run_native(query,
-               key_cache,
-               value_cache,
+               k_cache,
+               v_cache,
                block_tables,
                seq_lens,
                max_seq_len,
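
pertoken_quant_kvcache_symm now returns both scale layouts: k_scale_asm/v_scale_asm keep the per-token scales in the block-major form the asm kernel consumes, while k_scale/v_scale flatten them to [num_heads, total_tokens] for the naive path. A shape-only sketch, assuming aiter.ops.quant is importable and with made-up sizes:

    # Shape-only sketch of the two per-token scale layouts returned above.
    import torch
    from aiter.ops.quant import pertoken_quant

    num_blocks, num_heads, block_size, head_dim = 4, 2, 16, 128
    k_cache_permute = torch.randn(num_blocks, num_heads, block_size, head_dim)

    k_quant, k_scale_asm = pertoken_quant(
        k_cache_permute, torch.float32, quant_dtype=torch.int8)
    print(k_scale_asm.shape)      # [num_blocks, num_heads, block_size, 1]

    total_tokens = num_blocks * block_size
    k_scale = k_scale_asm.permute(1, 0, 2, 3).contiguous().view(num_heads, total_tokens)
    print(k_scale.shape)          # [num_heads, total_tokens]
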
@@ -214,9 +215,9 @@ def run_native(query,
                num_queries_per_kv):
     output = torch.zeros_like(query)
     num_query_heads = query.shape[1]
-    num_kv_heads = value_cache.shape[1]
-    head_size = value_cache.shape[2]
-    block_size = value_cache.shape[3]
+    num_kv_heads = v_cache.shape[1]
+    head_size = v_cache.shape[2]
+    block_size = v_cache.shape[3]
     num_seqs = query.shape[0]

     block_tables_lst = block_tables.cpu().tolist()
@@ -232,11 +233,11 @@ def run_native(query,
             block_number = int(block_table[j // block_size])
             block_offset = j % block_size

-            k = key_cache[block_number, :, :, block_offset, :]
+            k = k_cache[block_number, :, :, block_offset, :]
             k = k.reshape(num_kv_heads, head_size)
             keys_lst.append(k)

-            v = value_cache[block_number, :, :, block_offset]
+            v = v_cache[block_number, :, :, block_offset]
             values_lst.append(v)
         keys = torch.stack(keys_lst, dim=0)
         values = torch.stack(values_lst, dim=0)
@@ -261,21 +262,21 @@

 @perftest()
 def run_aiter(query,
-              key_cache,
-              value_cache,
-              block_tables,
-              seq_lens,
-              max_seq_len,
-              kv_cache_dtype,
-              num_kv_heads,
-              scale,
-              alibi_slopes,
-              k_scale,
-              v_scale,):
+              k_cache,
+              v_cache,
+              block_tables,
+              seq_lens,
+              max_seq_len,
+              kv_cache_dtype,
+              num_kv_heads,
+              scale,
+              alibi_slopes,
+              k_scale,
+              v_scale,):
     return ops.PagedAttention.forward_decode(
         query,
-        key_cache,
-        value_cache,
+        k_cache,
+        v_cache,
         block_tables,
         seq_lens,
         max_seq_len,
@@ -290,25 +291,25 @@

 @perftest()
 def run_aiter_naive(query,
-                    key_cache,
-                    value_cache,
-                    block_tables,
-                    seq_lens,
-                    k_dequant_scales,
-                    v_dequant_scales,
-                    max_seq_len,
-                    kv_cache_dtype,
-                    num_kv_heads,
-                    scale,
-                    alibi_slopes,
-                    k_scale,
-                    v_scale,
-                    block_size,
-                    quant_algo=0):
+                    k_cache,
+                    v_cache,
+                    block_tables,
+                    seq_lens,
+                    k_dequant_scales,
+                    v_dequant_scales,
+                    max_seq_len,
+                    kv_cache_dtype,
+                    num_kv_heads,
+                    scale,
+                    alibi_slopes,
+                    k_scale,
+                    v_scale,
+                    block_size,
+                    quant_algo=0):
     return aiter.pa_fwd_naive(
         query,
-        key_cache,
-        value_cache,
+        k_cache,
+        v_cache,
         block_tables,
         seq_lens,
         k_dequant_scales,
@@ -325,22 +326,22 @@

 @perftest()
 def run_aiter_asm(query,
-                  key_cache,
-                  value_cache,
-                  block_tables,
-                  seq_lens,
-                  max_seq_len,
-                  kv_cache_dtype,
-                  num_kv_heads,
-                  scale,
-                  alibi_slopes,
-                  max_num_blocks,
-                  k_scale=None,
-                  v_scale=None):
+                  k_cache,
+                  v_cache,
+                  block_tables,
+                  seq_lens,
+                  max_seq_len,
+                  kv_cache_dtype,
+                  num_kv_heads,
+                  scale,
+                  alibi_slopes,
+                  max_num_blocks,
+                  k_scale=None,
+                  v_scale=None):
     return aiter.pa_fwd_asm(
         query,
-        key_cache,
-        value_cache,
+        k_cache,
+        v_cache,
         block_tables,
         seq_lens,
         max_num_blocks,
@@ -350,8 +351,8 @@


 def dump_input(query: torch.Tensor,
-               key_cache: torch.Tensor,
-               value_cache: torch.Tensor,
+               k_cache: torch.Tensor,
+               v_cache: torch.Tensor,
                block_tables: torch.Tensor,
                seq_lens: torch.Tensor,
                max_seq_len: int,
@@ -360,14 +361,21 @@ def dump_input(query: torch.Tensor,
                scale: float,
                alibi_slopes: Optional[torch.Tensor],
                k_scale: float,
-               v_scale: float,):
-    tensor_dump(query, 'Q')
+               v_scale: float,
+               out_golden,
+               out_test):
+    path = '/mnt/raid0/ljin1/dk/ater/debug_ctx7'
+    tensor_dump(query, 'Q', path)
     # qbk = tensor_load('Q.bin')
     # checkAllclose(query, qbk)
-    tensor_dump(key_cache, 'K_cache')
-    tensor_dump(value_cache, 'V_cache')
-    tensor_dump(block_tables, 'block_tables')
-    tensor_dump(seq_lens, 'seq_lens')
+    tensor_dump(k_cache, 'K_cache', path)
+    tensor_dump(v_cache, 'V_cache', path)
+    tensor_dump(block_tables, 'block_tables', path)
+    tensor_dump(seq_lens, 'seq_lens', path)
+    tensor_dump(k_scale, 'k_scale', path)
+    tensor_dump(v_scale, 'v_scale', path)
+    tensor_dump(out_golden, 'out_golden', path)
+    tensor_dump(out_test, 'out_test', path)


 def load_input():
@@ -440,13 +448,16 @@ def test_paged_attention(
     if debug_mode == VERIFY:
         (query,
-         key_cache,
-         value_cache,
+         k_cache,
+         v_cache,
          block_tables,
          seq_lens,
          out_golden) = load_input()
     else:
-        query = torch.empty(num_seqs, num_query_heads, head_size, dtype=dtype)
+        query = torch.empty_strided(
+            (num_seqs, num_query_heads, head_size),
+            ((num_query_heads+2*num_kv_heads)*head_size, head_size, 1),
+            dtype=dtype)
         query.uniform_(*uniform_range)

     # seq_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)]
@@ -465,16 +476,16 @@
     block_tables = torch.tensor(block_tables_lst, dtype=torch.int)

     # Create the KV caches.
-    key_caches, value_caches = kv_cache_factory(num_blocks, block_size, 1,
-                                                num_kv_heads, head_size,
-                                                kv_cache_dtype, dtype, seed,
-                                                device)
-    key_cache, value_cache = key_caches[0], value_caches[0]
+    k_caches, v_caches = kv_cache_factory(num_blocks, block_size, 1,
+                                          num_kv_heads, head_size,
+                                          kv_cache_dtype, dtype, seed,
+                                          device)
+    k_cache, v_cache = k_caches[0], v_caches[0]

     out_aiter, time_aiter = run_aiter(
         query,
-        key_cache,
-        value_cache,
+        k_cache,
+        v_cache,
         block_tables,
         seq_lens,
         max_seq_len,
@@ -488,56 +499,99 @@
     if debug_mode != VERIFY:
         out_golden = out_aiter
     checkAllclose(out_golden, out_aiter,
-                  msg=f'golden vs aiter_shomy:{time_aiter}')
+                  msg=f'golden vs aiter_shomy:{time_aiter:.2f} us......')
     # tensor_dump(out_aiter, 'out_aiter')

-    out_aiter_asm, time_aiter_asm = run_aiter_asm(
-        query,
-        key_cache,
-        asm_V_shuffle(value_cache),
-        block_tables,
-        seq_lens,
-        max_seq_len,
-        kv_cache_dtype,
-        num_kv_heads,
-        scale,
-        alibi_slopes,
-        max_num_blocks_per_seq
-    )
-
-    checkAllclose(out_golden, out_aiter_asm,
-                  msg=f'golden vs aiter_asm:{time_aiter_asm}')
-    # tensor_dump(out_aiter, 'out_aiter')
-
-    for quant_algo_, cache_type_ in [(0, key_cache.dtype), (2, torch.float8_e4m3fnuz), (2, torch.int8)]:
-        if quant_algo_ == 0:
-            k_quant_, k_scale_, v_quant_, v_scale_ = key_cache, torch.empty(
-                (0)), value_cache, torch.empty((0))
-        else:
-            k_quant_, k_scale_, v_quant_, v_scale_ = pertoken_quant_kvcache_symm(
-                key_cache, value_cache, quant_dtype=cache_type_)
-        out_aiter_naive, time_aiter_naive = run_aiter_naive(
-            query,
-            k_quant_,
-            v_quant_,
+    if num_kv_heads == 1:
+        out_aiter_asm, time_aiter_asm = run_aiter_asm(
+            query.contiguous(),  # this kernel needs a contiguous buffer
+            k_cache,
+            asm_V_shuffle(v_cache),
             block_tables,
             seq_lens,
-            k_scale_,
-            v_scale_,
             max_seq_len,
             kv_cache_dtype,
             num_kv_heads,
             scale,
             alibi_slopes,
-            k_scale,
-            v_scale,
-            block_size,
-            quant_algo_
+            max_num_blocks_per_seq
         )
-        checkAllclose(out_aiter_asm, out_aiter_naive,
-                      msg=f'golden vs ck_naive(quant:{ck_naive_quant_algo[quant_algo_]}, kvcache:{cache_type_}):{time_aiter_naive}')
-        if cache_type_ == torch.int8:
+        checkAllclose(out_golden, out_aiter_asm,
+                      msg=f'golden vs aiter_asm:{time_aiter_asm:.2f} us......')
+        # tensor_dump(out_aiter, 'out_aiter')
+
+    for quant_algo_, cache_type_ in [
+        (0, k_cache.dtype),
+        (2, torch.float8_e4m3fnuz),
+        (2, torch.int8),
+        (4, torch.float8_e4m3fnuz),
+    ]:
+        quant_algo = ck_naive_quant_algo[quant_algo_]
+        if quant_algo == "NO":
+            k_quant_, k_scale_, v_quant_, v_scale_ = k_cache, torch.empty(
+                (0)), v_cache, torch.empty((0))
+        elif quant_algo == "KV_8BIT_PER_TOKEN":
+            k_quant_, k_scale_, v_quant_, v_scale_, k_scale_asm, v_scale_asm = pertoken_quant_kvcache_symm(
+                k_cache, v_cache, quant_dtype=cache_type_)
+        elif quant_algo == "KV_8BIT_PER_TENSOR":
+            k_quant_, k_scale_ = aiter.per_tensor_quant(
+                k_cache, quant_dtype=cache_type_)
+
+            x = 16 // cache_type_.itemsize
+            k_quant_ = k_quant_.permute(0, 1, 3, 2, 4).reshape(
+                num_blocks, num_kv_heads, block_size, -1).contiguous()
+            k_quant_ = k_quant_.view(num_blocks, num_kv_heads, block_size, head_size //
+                                     x, x).permute(0, 1, 3, 2, 4).contiguous()
+
+            v_quant_, v_scale_ = aiter.per_tensor_quant(
+                v_cache, quant_dtype=cache_type_)
+
+            k_scale_asm = torch.empty(num_blocks, num_kv_heads, block_size,
+                                      dtype=torch.float32, device=device)
+            v_scale_asm = torch.empty(num_blocks, num_kv_heads, block_size,
+                                      dtype=torch.float32, device=device)
+            k_scale_asm.fill_(k_scale_.item())
+            v_scale_asm.fill_(v_scale_.item())
+
+            out_aiter, time_aiter = run_aiter(
+                query,
+                k_quant_,
+                v_quant_,
+                block_tables,
+                seq_lens,
+                max_seq_len,
+                'fp8',
+                num_kv_heads,
+                scale,
+                alibi_slopes,
+                k_scale_.item(),
+                v_scale_.item(),
+            )
+            checkAllclose(out_golden, out_aiter,
+                          msg=f'golden vs shomy:{time_aiter:.2f} us......(quant:{ck_naive_quant_algo[quant_algo_]}, kvcache:{cache_type_})')
+        # out_aiter_naive, time_aiter_naive = run_aiter_naive(
+        #     query,
+        #     k_quant_,
+        #     v_quant_,
+        #     block_tables,
+        #     seq_lens,
+        #     k_scale_,
+        #     v_scale_,
+        #     max_seq_len,
+        #     kv_cache_dtype,
+        #     num_kv_heads,
+        #     scale,
+        #     alibi_slopes,
+        #     k_scale,
+        #     v_scale,
+        #     block_size,
+        #     quant_algo_
+        # )
+        # checkAllclose(out_aiter_asm, out_aiter_naive,
+        #               msg=f'golden vs ck_naive(quant:{ck_naive_quant_algo[quant_algo_]}, kvcache:{cache_type_}):{time_aiter_naive:.2f} us......')
+
+        if quant_algo_ != 0:
             out_aiter_asm, time_aiter_asm = run_aiter_asm(
                 query,
                 k_quant_,
@@ -550,15 +604,31 @@
                 scale,
                 alibi_slopes,
                 max_num_blocks_per_seq,
-                k_scale_,
-                v_scale_,
+                k_scale_asm,
+                v_scale_asm,
             )
             checkAllclose(out_golden, out_aiter_asm,
-                          msg=f'golden vs aiter_asm(quant:{ck_naive_quant_algo[quant_algo_]}, kvcache:{cache_type_}):{time_aiter_asm}')
+                          msg=f'golden vs aiter_asm:{time_aiter_asm:.2f} us......(quant:{ck_naive_quant_algo[quant_algo_]}, kvcache:{cache_type_})')
+            # if quant_algo == "KV_8BIT_PER_TENSOR":
+            #     dump_input(query,
+            #                k_quant_,
+            #                asm_V_shuffle(v_quant_),
+            #                block_tables,
+            #                seq_lens,
+            #                max_seq_len,
+            #                kv_cache_dtype,
+            #                num_kv_heads,
+            #                scale,
+            #                alibi_slopes,
+            #                k_scale_asm,
+            #                v_scale_asm,
+            #                out_golden,
+            #                out_aiter_asm)
+
     if debug_mode == DUMP:
         dump_input(query,
-                   key_cache,
-                   value_cache,
+                   k_cache,
+                   v_cache,
                    block_tables,
                    seq_lens,
                    max_seq_len,
@@ -567,12 +637,13 @@
                    scale,
                    alibi_slopes,
                    k_scale,
-                   v_scale,)
+                   v_scale,
+                   out_golden)

     # out_native, time_native = run_native(
     #     query,
-    #     key_cache,
-    #     value_cache,
+    #     k_cache,
+    #     v_cache,
     #     block_tables,
     #     seq_lens,
     #     max_seq_len,
@@ -595,6 +666,7 @@
     print(f'finish~ {ctx_lens=}, {num_seqs=}, {num_heads=}, {head_size=}, {use_alibi=}, {block_size=}, {dtype=}, {kv_cache_dtype=}\n')


-for ctx_len in [1, 26, 128, 4097]:
-    test_paged_attention(ctx_len, 128, (8, 1), 128, False, 16,
-                         torch.bfloat16, "auto", 0, "cuda:0")
+for num_heads in [(4, 1), (8, 1), (32, 8)]:
+    for ctx_len in [7, 26, 57, 66, 109, 128, 257, 282, 4097]:
+        test_paged_attention(ctx_len, 128, num_heads, 128, False, 16,
+                             torch.bfloat16, "auto", 0, "cuda:0")
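
For the new KV_8BIT_PER_TENSOR branch exercised above, the whole cache shares one scale, but the asm kernel still expects per-(block, head, token) scale tensors, so the scalar is broadcast into [num_blocks, num_kv_heads, block_size] buffers. A condensed sketch of that preparation (fp8 dtype and a ROCm GPU assumed; the extra regrouping of k_quant_ for the fp8 element size is shown in the hunk above and omitted here):

    # Condensed sketch of the per-tensor KV-cache preparation used above.
    import torch
    import aiter

    device = "cuda:0"
    num_blocks, num_kv_heads, block_size, head_size = 128, 8, 16, 128
    x = 16 // 2                                  # bf16 cache: 8 elements per 16-byte group
    k_cache = torch.randn(num_blocks, num_kv_heads, head_size // x, block_size, x,
                          dtype=torch.bfloat16, device=device)

    k_quant, k_scale = aiter.per_tensor_quant(k_cache, quant_dtype=torch.float8_e4m3fnuz)

    # one scale per (block, head, token) for the asm kernel: broadcast the scalar
    k_scale_asm = torch.full((num_blocks, num_kv_heads, block_size), k_scale.item(),
                             dtype=torch.float32, device=device)
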