update asm pa #90

Merged · 4 commits · Feb 6, 2025
5 changes: 4 additions & 1 deletion .gitignore
@@ -40,4 +40,7 @@ __pycache__
*.egg-info

# editor
.vscode
.vscode

# debug folder
debug
1 change: 0 additions & 1 deletion aiter/ops/attention.py
@@ -54,7 +54,6 @@ def pa_fwd_asm(
def paged_attention_rocm(
    out: torch.Tensor,
    exp_sums: torch.Tensor,
    block_mapping: torch.Tensor,
    max_logits: torch.Tensor,
    tmp_out: torch.Tensor,
    query: torch.Tensor,
26 changes: 20 additions & 6 deletions aiter/ops/quant.py
@@ -19,6 +19,14 @@ def moe_smoothquant_fwd(


# following are pure torch implement
def get_dtype_max(dtype):
    try:
        dtypeMax = torch.finfo(dtype).max
    except TypeError:  # not a floating-point dtype; fall back to the integer range
        dtypeMax = torch.iinfo(dtype).max
    return dtypeMax


def pertoken_quant(x, y_scale_dtype=torch.float, x_scale=None, quant_dtype=torch.int8):
    if x_scale is None:
        hidden_states = x
@@ -32,10 +40,7 @@ def pertoken_quant(x, y_scale_dtype=torch.float, x_scale=None, quant_dtype=torch
        keepdim=True
    )

    try:
        dtypeMax = torch.finfo(quant_dtype).max
    except:
        dtypeMax = torch.iinfo(quant_dtype).max
    dtypeMax = get_dtype_max(quant_dtype)

    per_token_scale = per_token_amax.to(dtype=torch.float32) / dtypeMax
    per_token_scale[per_token_scale == 0] = 1
@@ -46,10 +51,19 @@ def pertoken_quant(x, y_scale_dtype=torch.float, x_scale=None, quant_dtype=torch
    return y, y_scale


def per_tensor_quant(x, scale=None, scale_dtype=torch.float, quant_dtype=torch.int8):
    if scale is None:
        dtypeMax = get_dtype_max(quant_dtype)
        scale = torch.abs(x.to(torch.float)).max() / dtypeMax
    y = x/scale

    return y.to(quant_dtype), scale.to(scale_dtype)


@compile_ops("module_quant")
def static_scaled_fp8_quant(
    out: Tensor, input: Tensor, scale: Tensor
):...
): ...


@compile_ops("module_quant")
Expand All @@ -65,4 +79,4 @@ def dynamic_scaled_fp8_quant(
@compile_ops("module_quant")
def dynamic_per_token_scaled_fp8_quant(
out: Tensor, input: Tensor, scales: Tensor, scale_ub: Optional[Tensor] = None
):...
): ...
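
For reference, a minimal usage sketch of the refactored helpers, assuming `pertoken_quant` and `per_tensor_quant` are importable from `aiter.ops.quant` as shown in this diff; shapes and dtypes below are illustrative only.

```python
# Illustrative sketch only -- not part of the PR.
import torch
from aiter.ops.quant import pertoken_quant, per_tensor_quant

x = torch.randn(4, 128, dtype=torch.float16)

# Per-token: one scale per row (amax reduced over the last dim).
y_tok, scale_tok = pertoken_quant(x, quant_dtype=torch.int8)

# Per-tensor (new helper): a single scale from the global absmax.
y_ten, scale_ten = per_tensor_quant(x, quant_dtype=torch.int8)

print(scale_tok.shape)  # expected: torch.Size([4, 1]), one scale per token
print(scale_ten.shape)  # expected: a 0-dim (scalar) tensor
```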
18 changes: 10 additions & 8 deletions aiter/test_common.py
@@ -106,27 +106,29 @@ def get_trace_perf(prof, num_iters):
    return df.at[avg_name, 'device_time_total']


def checkAllclose(a, b, rtol=1e-2, atol=1e-2, msg=''):
def checkAllclose(a, b, rtol=1e-2, atol=1e-2, msg='', printNum=8):
    isClose = torch.isclose(a, b, rtol=rtol, atol=atol)
    mask = ~isClose
    if isClose.all():
        logger.info(f'{msg}[checkAllclose {atol=} {rtol=} passed~]')
    else:
        percent = (a[mask]).numel()/a.numel()
        num = mask.sum()
        printNum = min(printNum, num)
        percent = num/a.numel()
        delta = (a-b)[mask]
        if percent > 0.01:
            logger.info(f'''{msg}[checkAllclose {atol=} {rtol=} failed!]
a: {a.shape}
{a[mask]}
b: {b.shape}
{b[mask]}
a : {a.shape}
{a[mask][:printNum]}
b : {b.shape}
{b[mask][:printNum]}
delta:
{delta}''')
{delta[:printNum]}''')
        else:
            logger.info(
                f'''{msg}[checkAllclose {atol=} {rtol=} warning!] a and b results are not all close''')
        logger.info(
            f'-->max delta:{delta.max()}, delta details: {percent:.1%} ({(a[mask]).numel()} of {a.numel()}) elements')
            f'-->max delta:{delta.max()}, delta details: {percent:.1%} ({num} of {a.numel()}) elements')


def tensor_dump(x: torch.tensor, name: str, dir='./'):
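
A small usage sketch for the extended `checkAllclose`, assuming `aiter.test_common` is importable; the new `printNum` argument only caps how many mismatching elements are dumped to the log.

```python
# Illustrative sketch only -- not part of the PR.
import torch
from aiter.test_common import checkAllclose

a = torch.randn(1024)
b = a.clone()
b[::7] += 0.5  # perturb ~14% of the elements so the check fails

# Only the first 4 mismatching values of a, b and their delta are logged.
checkAllclose(a, b, rtol=1e-2, atol=1e-2, msg='demo ', printNum=4)
```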
16 changes: 14 additions & 2 deletions csrc/kernels/cache_kernels.cu
@@ -469,8 +469,20 @@ namespace vllm
    }

    // store the scale
    k_dequant_scales[head_idx * max_kv_tokens + slot_idx] = k_token_scale;
    v_dequant_scales[head_idx * max_kv_tokens + slot_idx] = v_token_scale;
    if constexpr (asmLayout)
    {
        // [num_blocks, num_heads, block_size]
        const int scale_idx = block_size * num_heads * block_idx +
                              block_size * head_idx +
                              block_offset;
        k_dequant_scales[scale_idx] = k_token_scale;
        v_dequant_scales[scale_idx] = v_token_scale;
    }
    else
    {
        k_dequant_scales[head_idx * max_kv_tokens + slot_idx] = k_token_scale;
        v_dequant_scales[head_idx * max_kv_tokens + slot_idx] = v_token_scale;
    }

    // now let's store out
#pragma unroll
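
To make the new indexing concrete: under `asmLayout`, the per-token dequant scales are stored as a contiguous `[num_blocks, num_heads, block_size]` buffer instead of `[num_heads, max_kv_tokens]`. Below is a small sketch of the index math; variable names mirror the kernel arguments and the sizes are made up.

```python
# Illustrative sketch only -- mirrors the scale indexing added above.
import torch

num_blocks, num_heads, block_size = 8, 4, 16
max_kv_tokens = num_blocks * block_size

def asm_scale_index(block_idx, head_idx, block_offset):
    # Flat offset into a [num_blocks, num_heads, block_size] buffer.
    return block_size * num_heads * block_idx + block_size * head_idx + block_offset

def legacy_scale_index(head_idx, slot_idx):
    # Flat offset into a [num_heads, max_kv_tokens] buffer.
    return head_idx * max_kv_tokens + slot_idx

# Sanity check: the asm index matches a contiguous 3-D view.
flat = torch.arange(num_blocks * num_heads * block_size)
assert flat[asm_scale_index(3, 2, 5)] == flat.view(num_blocks, num_heads, block_size)[3, 2, 5]
```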
48 changes: 35 additions & 13 deletions csrc/py_itfs_cu/asm_pa.cpp
@@ -28,14 +28,16 @@ struct __attribute__((packed)) KernelArgs
    p3 _p12;
    unsigned int mblk;
    p3 _p13;
    unsigned int batch;
    unsigned int kv_nheads;
    p3 _p14;
    unsigned int Qs;
    p3 _p15;
    unsigned int Bs;
    p3 _p16;
    unsigned int KVs;
    p3 _p17;
    unsigned int GQA;
    p3 _p18;
};

const float f_log2E = log2f(expf(1));
@@ -62,7 +64,7 @@ torch::Tensor pa_fwd(torch::Tensor &Q, // [num_seqs, num_heads, hea
                __func__, " for now only support block_size == 16");

    int dim = head_size;
    int stride_Q = gqa_ratio * dim * Q.itemsize();
    int stride_Q = Q.stride(0) * Q.itemsize();
    int stride_KV_head = block_size * dim * K.itemsize();
    int stride_KV_blk = stride_KV_head * num_kv_heads;
    float k_log2e = f_log2E;
@@ -89,27 +91,47 @@ torch::Tensor pa_fwd(torch::Tensor &Q, // [num_seqs, num_heads, hea
    }
    args.sclg2e = k_scalar;
    args.mblk = max_num_blocks;
    args.batch = batch;
    args.kv_nheads = num_kv_heads;
    args.Qs = stride_Q;
    args.Bs = stride_KV_blk;
    args.KVs = stride_KV_head;
    args.GQA = gqa_ratio;
    // std::cout << "sclg2e: " << args.sclg2e << " mblk:" << args.mblk << " kv_nheads:" << args.kv_nheads << " Qs:" << args.Qs << " Bs:" << args.Bs << " KVs:" << args.KVs << std::endl;

    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
    static AiterAsmKernel impl_a16w16("pa_kernel_func", "pa_a16w16.co");
    static AiterAsmKernel impl_a16w8("pa_kernel_func", "pa_a16w8.co");
    AiterAsmKernel *impl_ptr = &impl_a16w16;

    AiterAsmKernel *impl_ptr = nullptr;
    if (K_QScale)
        impl_ptr = &impl_a16w8;
    {
        if (K.dtype() == at::ScalarType::Char)
        {
            static AiterAsmKernel impl_a16w8_i8("pa_a16w8_2tg_g8_i8", "pa_a16w8_2tg_g8_i8.co");
            impl_ptr = &impl_a16w8_i8;
        }
        else if (K.dtype() == at::ScalarType::Float8_e4m3fnuz)
        {
            static AiterAsmKernel impl_a16w8_f8("pa_a16w8_2tg_g8_f8", "pa_a16w8_2tg_g8_f8.co");
            impl_ptr = &impl_a16w8_f8;
        }
    }
    else
    {
        TORCH_CHECK(Q.is_contiguous(),
                    __func__, ":a16w16 only support Q.is_contiguous() for now");
        static AiterAsmKernel impl_a16w16("pa_kernel_func", "pa_a16w16.co");
        impl_ptr = &impl_a16w16;
    }
    TORCH_CHECK(impl_ptr != nullptr,
                __func__, ": unsupport current input type");

    impl_ptr->launch_kernel({&args,
                             &arg_size,
                             1,            // gdx
                             batch,        // gdy
                             1,            // gdz
                             256,          // bdx: 4 wv64
                             1,            // bdy
                             1,            // bdz
                             num_kv_heads, // gdx
                             batch,        // gdy
                             1,            // gdz
                             256,          // bdx: 4 wv64
                             1,            // bdy
                             1,            // bdz
                             stream});
    return output;
}
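
In short, the a16w8 path now dispatches on the KV dtype, the contiguity requirement applies only to the unquantized a16w16 path, and the launch grid gains a `num_kv_heads` dimension on gdx. The sketch below restates that selection logic in Python for readability; the real dispatch is the C++ above, and `select_pa_kernel` is purely illustrative.

```python
# Illustrative sketch only -- restates the dispatch logic of pa_fwd above.
import torch

def select_pa_kernel(K: torch.Tensor, K_QScale, Q: torch.Tensor):
    """Return (kernel_name, code_object) the way the C++ dispatcher now does."""
    if K_QScale is not None:
        if K.dtype == torch.int8:
            return "pa_a16w8_2tg_g8_i8", "pa_a16w8_2tg_g8_i8.co"
        # float8_e4m3fnuz requires a torch build with fnuz float8 support (ROCm).
        if K.dtype == torch.float8_e4m3fnuz:
            return "pa_a16w8_2tg_g8_f8", "pa_a16w8_2tg_g8_f8.co"
        raise TypeError("unsupported quantized KV dtype")
    # The unquantized kernel still requires a contiguous Q.
    assert Q.is_contiguous(), "a16w16 only supports contiguous Q for now"
    return "pa_kernel_func", "pa_a16w16.co"
```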
Binary file added hsa/pa_a16w8_2tg_g8_f8.co
Binary file not shown.
Binary file added hsa/pa_a16w8_2tg_g8_i8.co
Binary file not shown.
30 changes: 22 additions & 8 deletions op_tests/test_kvcache.py
@@ -30,10 +30,15 @@ def run_torch(key, value, k_cache, v_cache, slot_mapping, block_size, x, asm_lay
            quant_dtype=quantCfg['quant_dtype'])
        k_scale_ = k_scale_.permute(0, 1, 3, 2).view(
            num_batch*num_tokens, num_heads).contiguous()

        k_scale = k_scale.permute(1, 0).contiguous()
        k_scale[slot_mapping] = k_scale_
        k_scale = k_scale.permute(1, 0).contiguous()
        if asm_layout:
            k_scale = k_scale.permute(0, 2, 1).contiguous().view(-1, num_heads)
            k_scale[slot_mapping] = k_scale_
            k_scale = k_scale.view(
                num_blocks, block_size, num_heads).permute(0, 2, 1).contiguous()
        else:
            k_scale = k_scale.permute(1, 0).contiguous()
            k_scale[slot_mapping] = k_scale_
            k_scale = k_scale.permute(1, 0).contiguous()

    k_cache = k_cache.permute(0, 3, 1, 2, 4).contiguous().view(
        -1, num_heads, head_size)
@@ -48,10 +53,17 @@ def run_torch(key, value, k_cache, v_cache, slot_mapping, block_size, x, asm_lay
            quant_dtype=quantCfg['quant_dtype'])
        v_scale_ = v_scale_.permute(0, 1, 3, 2).view(
            num_batch*num_tokens, num_heads).contiguous()
        if asm_layout:
            v_scale = v_scale.permute(
                0, 2, 1).contiguous().view(-1, num_heads)
            v_scale[slot_mapping] = v_scale_
            v_scale = v_scale.view(
                num_blocks, block_size, num_heads).permute(0, 2, 1).contiguous()
        else:
            v_scale = v_scale.permute(1, 0).contiguous()
            v_scale[slot_mapping] = v_scale_
            v_scale = v_scale.permute(1, 0).contiguous()

        v_scale = v_scale.permute(1, 0).contiguous()
        v_scale[slot_mapping] = v_scale_
        v_scale = v_scale.permute(1, 0).contiguous()
    if asm_layout:
        v_cache = v_cache.permute(0, 2, 4, 1, 3).contiguous().view(
            -1, num_heads, head_size)
@@ -105,9 +117,11 @@ def test_reshape_and_cache(ctx_lens: int,
    if asm_layout:
        k_cache_shape = (bs*num_blocks, kvhead, head_size // x, block_size, x)
        v_cache_shape = (bs*num_blocks, kvhead, block_size//x, head_size, x)
        kv_scale_shape = (bs*num_blocks, kvhead, block_size)
    else:
        k_cache_shape = (bs*num_blocks, kvhead, head_size // x, block_size, x)
        v_cache_shape = (bs*num_blocks, kvhead, head_size, block_size)
        kv_scale_shape = (kvhead, bs*max_token_num_support)

    # ##################################################### prefill part
    qkv = torch.randn(
@@ -117,7 +131,7 @@ def test_reshape_and_cache(ctx_lens: int,
    k_cache = torch.empty(k_cache_shape, dtype=DTyoe_KVCache, device=device)
    v_cache = torch.empty(v_cache_shape, dtype=DTyoe_KVCache, device=device)
    if quantCfg:
        k_scale = torch.empty(kvhead, bs*max_token_num_support,
        k_scale = torch.empty(kv_scale_shape,
                              dtype=quantCfg['y_scale_dtype'], device=key.device)
        v_scale = torch.empty_like(k_scale)
        quantCfg['k_scale'] = k_scale.clone()
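
The asm-layout scale bookkeeping above is the least obvious part of the test change: scales live in a `[num_blocks, num_heads, block_size]` buffer, so the reference path flattens slots to the front, scatters by `slot_mapping`, and then restores the layout. A standalone sketch with toy sizes:

```python
# Illustrative sketch only -- the asm-layout scale update used in run_torch above.
import torch

num_blocks, num_heads, block_size = 4, 2, 16
num_tokens = 10

k_scale = torch.zeros(num_blocks, num_heads, block_size)  # asm layout
slot_mapping = torch.arange(num_tokens)                   # token -> slot
k_scale_new = torch.rand(num_tokens, num_heads)           # per-token, per-head scales

# Bring slots to the front so slot_mapping can index rows directly ...
flat = k_scale.permute(0, 2, 1).contiguous().view(-1, num_heads)
flat[slot_mapping] = k_scale_new
# ... then restore the [num_blocks, num_heads, block_size] layout.
k_scale = flat.view(num_blocks, block_size, num_heads).permute(0, 2, 1).contiguous()
```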