From 9db52eab3dc0b7b2cf30fa4399d569131e90c2d4 Mon Sep 17 00:00:00 2001 From: rasmith Date: Fri, 6 Sep 2024 17:26:09 -0500 Subject: [PATCH] [Kernel] [Triton] Memory optimization for awq_gemm and awq_dequantize, 2x throughput (#8248) --- .../layers/quantization/awq_triton.py | 34 +++++++++++++------ 1 file changed, 23 insertions(+), 11 deletions(-) diff --git a/vllm/model_executor/layers/quantization/awq_triton.py b/vllm/model_executor/layers/quantization/awq_triton.py index ad706f28a742b..d0b210c3a2747 100644 --- a/vllm/model_executor/layers/quantization/awq_triton.py +++ b/vllm/model_executor/layers/quantization/awq_triton.py @@ -22,7 +22,7 @@ def awq_dequantize_kernel( # Compute offsets and masks for qweight_ptr. offsets_y = pid_y * BLOCK_SIZE_Y + tl.arange(0, BLOCK_SIZE_Y) - offsets_x = pid_x * BLOCK_SIZE_X + tl.arange(0, BLOCK_SIZE_X * 8) // 8 + offsets_x = pid_x * BLOCK_SIZE_X + tl.arange(0, BLOCK_SIZE_X) offsets = num_cols * offsets_y[:, None] + offsets_x[None, :] masks_y = offsets_y < num_rows @@ -43,6 +43,9 @@ def awq_dequantize_kernel( # Load the weights. iweights = tl.load(qweight_ptr + offsets, masks) + iweights = tl.interleave(iweights, iweights) + iweights = tl.interleave(iweights, iweights) + iweights = tl.interleave(iweights, iweights) # Create reverse AWQ order as tensor: [0, 4, 1, 5, 2, 6, 3, 7] # that will map given indices to the correct order. @@ -59,9 +62,8 @@ def awq_dequantize_kernel( iweights = (iweights >> shifts) & 0xF # Compute zero offsets and masks. - zero_offsets_y = (pid_y * BLOCK_SIZE_Y // group_size + - tl.arange(0, BLOCK_SIZE_Y) // group_size) - zero_offsets_x = pid_x * BLOCK_SIZE_X + tl.arange(0, BLOCK_SIZE_X * 8) // 8 + zero_offsets_y = pid_y * BLOCK_SIZE_Y // group_size + tl.arange(0, 1) + zero_offsets_x = pid_x * BLOCK_SIZE_X + tl.arange(0, BLOCK_SIZE_X) zero_offsets = num_cols * zero_offsets_y[:, None] + zero_offsets_x[None, :] zero_masks_y = zero_offsets_y < num_rows // group_size @@ -70,13 +72,16 @@ def awq_dequantize_kernel( # Load the zeros. zeros = tl.load(zeros_ptr + zero_offsets, zero_masks) + zeros = tl.interleave(zeros, zeros) + zeros = tl.interleave(zeros, zeros) + zeros = tl.interleave(zeros, zeros) + zeros = tl.broadcast_to(zeros, (BLOCK_SIZE_Y, BLOCK_SIZE_X * 8)) # Unpack and reorder: shift out the correct 4-bit value and mask. zeros = (zeros >> shifts) & 0xF # Compute scale offsets and masks. - scale_offsets_y = (pid_y * BLOCK_SIZE_Y // group_size + - tl.arange(0, BLOCK_SIZE_Y) // group_size) + scale_offsets_y = pid_y * BLOCK_SIZE_Y // group_size + tl.arange(0, 1) scale_offsets_x = (pid_x * BLOCK_SIZE_X * 8 + tl.arange(0, BLOCK_SIZE_X * 8)) scale_offsets = (num_cols * 8 * scale_offsets_y[:, None] + @@ -87,6 +92,7 @@ def awq_dequantize_kernel( # Load the scales. scales = tl.load(scales_ptr + scale_offsets, scale_masks) + scales = tl.broadcast_to(scales, (BLOCK_SIZE_Y, BLOCK_SIZE_X * 8)) # Dequantize. iweights = (iweights - zeros) * scales @@ -137,12 +143,10 @@ def awq_gemm_kernel(a_ptr, b_ptr, c_ptr, zeros_ptr, scales_ptr, M, N, K, offsets_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) masks_am = offsets_am < M - offsets_bn = (pid_n * (BLOCK_SIZE_N // 8) + - tl.arange(0, BLOCK_SIZE_N) // 8) + offsets_bn = pid_n * (BLOCK_SIZE_N // 8) + tl.arange(0, BLOCK_SIZE_N // 8) masks_bn = offsets_bn < N // 8 - offsets_zn = (pid_n * (BLOCK_SIZE_N // 8) + - tl.arange(0, BLOCK_SIZE_N) // 8) + offsets_zn = pid_n * (BLOCK_SIZE_N // 8) + tl.arange(0, BLOCK_SIZE_N // 8) masks_zn = offsets_zn < N // 8 offsets_sn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) @@ -165,22 +169,30 @@ def awq_gemm_kernel(a_ptr, b_ptr, c_ptr, zeros_ptr, scales_ptr, M, N, K, masks_b = masks_k[:, None] & masks_bn[None, :] b = tl.load(b_ptrs, mask=masks_b) + b = tl.interleave(b, b) + b = tl.interleave(b, b) + b = tl.interleave(b, b) # Dequantize b. offsets_szk = ( (BLOCK_SIZE_K * SPLIT_K * k + pid_z * BLOCK_SIZE_K) // group_size + - tl.arange(0, BLOCK_SIZE_K) // group_size) + tl.arange(0, 1)) offsets_z = (N // 8) * offsets_szk[:, None] + offsets_zn[None, :] masks_zk = offsets_szk < K // group_size masks_z = masks_zk[:, None] & masks_zn[None, :] zeros_ptrs = zeros_ptr + offsets_z zeros = tl.load(zeros_ptrs, mask=masks_z) + zeros = tl.interleave(zeros, zeros) + zeros = tl.interleave(zeros, zeros) + zeros = tl.interleave(zeros, zeros) + zeros = tl.broadcast_to(zeros, (BLOCK_SIZE_K, BLOCK_SIZE_N)) offsets_s = N * offsets_szk[:, None] + offsets_sn[None, :] masks_sk = offsets_szk < K // group_size masks_s = masks_sk[:, None] & masks_sn[None, :] scales_ptrs = scales_ptr + offsets_s scales = tl.load(scales_ptrs, mask=masks_s) + scales = tl.broadcast_to(scales, (BLOCK_SIZE_K, BLOCK_SIZE_N)) b = (b >> shifts) & 0xF zeros = (zeros >> shifts) & 0xF