diff --git a/benchmarks/kernels/benchmark_moe.py b/benchmarks/kernels/benchmark_moe.py index 8f538c21f7f7e..1d59a01422412 100644 --- a/benchmarks/kernels/benchmark_moe.py +++ b/benchmarks/kernels/benchmark_moe.py @@ -1,6 +1,7 @@ import argparse import time from datetime import datetime +from itertools import product from typing import Any, Dict, List, Tuple, TypedDict import ray @@ -11,7 +12,10 @@ from vllm.model_executor.layers.fused_moe.fused_moe import * from vllm.platforms import current_platform -from vllm.utils import FlexibleArgumentParser +from vllm.utils import FlexibleArgumentParser, is_navi + +FP8_DTYPE = torch.float8_e4m3fnuz if current_platform.is_rocm( +) and not is_navi() else torch.float8_e4m3fn class BenchmarkConfig(TypedDict): @@ -80,8 +84,8 @@ def benchmark_config( a1_scale = torch.randn(1, dtype=torch.float32) a2_scale = torch.randn(1, dtype=torch.float32) - w1 = w1.to(torch.float8_e4m3fn) - w2 = w2.to(torch.float8_e4m3fn) + w1 = w1.to(FP8_DTYPE) + w2 = w2.to(FP8_DTYPE) input_gating = torch.empty(num_tokens, num_experts, dtype=torch.float32) @@ -141,28 +145,172 @@ def run(): return avg -def get_configs_compute_bound() -> List[Dict[str, int]]: - # Reduced search space for faster tuning. - # TODO(woosuk): Increase the search space and use a performance model to - # prune the search space. +def get_rocm_tuning_space(use_fp16): + block_mn_range = [16, 32, 64, 128, 256] + block_k_range = [16, 32, 64, 128, 256] + if not use_fp16: + block_k_range.remove(16) # BLOCK_K=16 not supported for fp8 + num_warps_range = [1, 2, 4, 8] + group_m_range = [1, 4, 8, 16, 32] + num_stage_range = [2] + waves_per_eu_range = [0] + matrix_instr_nonkdim_range = [16, 32] if use_fp16 else [] + kpack_range = [1, 2] if use_fp16 else [] + + param_ranges = { + "BLOCK_SIZE_M": block_mn_range, + "BLOCK_SIZE_N": block_mn_range, + "BLOCK_SIZE_K": block_k_range, + "GROUP_SIZE_M": group_m_range, + "num_warps": num_warps_range, + "num_stages": num_stage_range, + "waves_per_eu": waves_per_eu_range, + } + if use_fp16: + param_ranges["matrix_instr_nonkdim"] = matrix_instr_nonkdim_range + param_ranges["kpack"] = kpack_range + + return param_ranges + + +def get_configs_compute_bound(use_fp16) -> List[Dict[str, int]]: configs: List[BenchmarkConfig] = [] - for num_stages in [2, 3, 4, 5]: - for block_m in [16, 32, 64, 128, 256]: - for block_k in [64, 128, 256]: - for block_n in [32, 64, 128, 256]: - for num_warps in [4, 8]: - for group_size in [1, 16, 32, 64]: - configs.append({ - "BLOCK_SIZE_M": block_m, - "BLOCK_SIZE_N": block_n, - "BLOCK_SIZE_K": block_k, - "GROUP_SIZE_M": group_size, - "num_warps": num_warps, - "num_stages": num_stages, - }) + + if current_platform.is_rocm(): + param_ranges = get_rocm_tuning_space(use_fp16) + else: + # Reduced search space for faster tuning. + # TODO(woosuk): Increase the search space and use a performance model to + # prune the search space. + block_m_range = [16, 32, 64, 128, 256] + block_n_range = [32, 64, 128, 256] + block_k_range = [64, 128, 256] + num_warps_range = [4, 8] + group_m_range = [1, 16, 32, 64] + num_stage_range = [2, 3, 4, 5] + + param_ranges = { + "BLOCK_SIZE_M": block_m_range, + "BLOCK_SIZE_N": block_n_range, + "BLOCK_SIZE_K": block_k_range, + "GROUP_SIZE_M": group_m_range, + "num_warps": num_warps_range, + "num_stages": num_stage_range, + } + + keys, values = zip(*param_ranges.items()) + for config_values in product(*values): + config = dict(zip(keys, config_values)) + configs.append(config) return configs +def prune_rocm_search_space(num_tokens, shard_intermediate_size, hidden_size, + search_space, is_fp16): + N1, K1 = shard_intermediate_size, hidden_size + N2, K2 = hidden_size, shard_intermediate_size // 2 + pruned_space_1 = prune_rocm_configs(num_tokens * 2, N1, K1, search_space, + is_fp16) + pruned_space_2 = prune_rocm_configs(num_tokens * 2, N2, K2, search_space, + is_fp16) + search_space = merge_unique_dicts(pruned_space_1, pruned_space_2) + return search_space + + +# The following code is inspired by ROCm/Triton GEMM tuning script: +# https://github.com/ROCm/triton/blob/triton-mlir/scripts/amd/gemm/tune_gemm.py#L89 +def prune_rocm_configs(M, N, K, configs, is_fp16=True): + pruned_configs = [] + elemBytes_a = 2 if is_fp16 else 1 + elemBytes_b = 2 if is_fp16 else 1 + + mfma = 16 if M < 32 or N < 32 else 32 + + # TODO (zhanglx): figure out the boundary between large and small gemms + large_gemm = False + if M >= 2048 and N >= 2048: + large_gemm = True + + for config in configs: + BLOCK_SIZE_M = config.get("BLOCK_SIZE_M") + BLOCK_SIZE_N = config.get("BLOCK_SIZE_N") + BLOCK_SIZE_K = config.get("BLOCK_SIZE_K") + num_warps = config.get("num_warps") + + if is_fp16: + matrix_instr_nonkdim = config.get("matrix_instr_nonkdim") + if matrix_instr_nonkdim > mfma: + continue + if mfma == 4 and BLOCK_SIZE_K < 64: + continue + # some layouts could not work properly in case + # number elements per thread is less 1 + if BLOCK_SIZE_M * BLOCK_SIZE_N < 64: + continue + SPLIT_K = config.get("SPLIT_K", 1) + GROUP_M = config.get("GROUP_SIZE_M") + if is_fp16: + if (matrix_instr_nonkdim > BLOCK_SIZE_M + or matrix_instr_nonkdim > BLOCK_SIZE_N): + continue + if (matrix_instr_nonkdim >= M + and matrix_instr_nonkdim != BLOCK_SIZE_M): + continue + if (matrix_instr_nonkdim >= N + and matrix_instr_nonkdim != BLOCK_SIZE_N): + continue + # Skip BLOCK_SIZE that is too large compare to M/N + # unless BLOCK_SIZE is already small enough + if M * 2 < BLOCK_SIZE_M and BLOCK_SIZE_M != 16: + continue + if N * 2 < BLOCK_SIZE_N and BLOCK_SIZE_N != 16: + continue + # skip large split_k when not necessary + if SPLIT_K != 1 and not need_split_k(M, N, K): + continue + # skip split_k that leads to EVEN_K = false + leap = SPLIT_K * BLOCK_SIZE_K + modv = K % leap + if modv != 0: + continue + # skip large GROUP_M + if GROUP_M * BLOCK_SIZE_M > M and GROUP_M != 1: + continue + # out of shared memory resource + # TODO (zhanglx): This does not consider the LDS usage in the epilogue + LDS = (BLOCK_SIZE_K * BLOCK_SIZE_M * elemBytes_a + + BLOCK_SIZE_K * BLOCK_SIZE_N * elemBytes_b) + if LDS > 65536: + continue + # Skip small block sizes and num_warps for large gemm + # For fp16 and f8, we want to only use BLOCK_SIZE >= 64 + if large_gemm: + if BLOCK_SIZE_M < 64 or BLOCK_SIZE_N < 64: + continue + if BLOCK_SIZE_K < 64: + continue + if num_warps < 4: + continue + + pruned_configs.append(config) + + return pruned_configs + + +def need_split_k(SIZE_M, SIZE_N, SIZE_K): + return (SIZE_M < 64 or SIZE_N < 64) and SIZE_K > 1024 + + +def merge_unique_dicts(list1, list2): + result = [] + combined_list = list1.copy() + combined_list.extend(list2) + for dictionary in combined_list: + if dictionary not in result: + result.append(dictionary) + return result + + @ray.remote(num_gpus=1) class BenchmarkWorker: @@ -170,6 +318,10 @@ def __init__(self, seed: int) -> None: torch.set_default_device("cuda") current_platform.seed_everything(seed) self.seed = seed + # Get the device ID to allocate tensors and kernels + # on the respective GPU. This is required for Ray to work + # correctly with multi-GPU tuning on the ROCm platform. + self.device_id = int(ray.get_gpu_ids()[0]) def benchmark( self, @@ -217,25 +369,33 @@ def tune( ) -> Dict[str, int]: best_config = None best_time = float("inf") - for config in tqdm(search_space): - try: - kernel_time = benchmark_config(config, - num_tokens, - num_experts, - shard_intermediate_size, - hidden_size, - topk, - dtype, - use_fp8_w8a8, - use_int8_w8a16, - num_iters=10) - except triton.runtime.autotuner.OutOfResources: - # Some configurations may be invalid and fail to compile. - continue - - if kernel_time < best_time: - best_time = kernel_time - best_config = config + if current_platform.is_rocm(): + is_fp16 = not (use_fp8_w8a8 or use_int8_w8a16) + search_space = prune_rocm_search_space(num_tokens, + shard_intermediate_size, + hidden_size, search_space, + is_fp16) + + with torch.cuda.device(self.device_id): + for config in tqdm(search_space): + try: + kernel_time = benchmark_config(config, + num_tokens, + num_experts, + shard_intermediate_size, + hidden_size, + topk, + dtype, + use_fp8_w8a8, + use_int8_w8a16, + num_iters=20) + except triton.runtime.autotuner.OutOfResources: + # Some configurations may be invalid and fail to compile. + continue + + if kernel_time < best_time: + best_time = kernel_time + best_config = config now = datetime.now() print(f"{now.ctime()}] Completed tuning for batch_size={num_tokens}") assert best_config is not None @@ -244,12 +404,27 @@ def tune( def sort_config(config: BenchmarkConfig) -> BenchmarkConfig: return { - "BLOCK_SIZE_M": config["BLOCK_SIZE_M"], - "BLOCK_SIZE_N": config["BLOCK_SIZE_N"], - "BLOCK_SIZE_K": config["BLOCK_SIZE_K"], - "GROUP_SIZE_M": config["GROUP_SIZE_M"], - "num_warps": config["num_warps"], - "num_stages": config["num_stages"], + "BLOCK_SIZE_M": + config["BLOCK_SIZE_M"], + "BLOCK_SIZE_N": + config["BLOCK_SIZE_N"], + "BLOCK_SIZE_K": + config["BLOCK_SIZE_K"], + "GROUP_SIZE_M": + config["GROUP_SIZE_M"], + "num_warps": + config["num_warps"], + "num_stages": + config["num_stages"], + **({ + "waves_per_eu": config["waves_per_eu"] + } if "waves_per_eu" in config else {}), + **({ + "matrix_instr_nonkdim": config["matrix_instr_nonkdim"] + } if "matrix_instr_nonkdim" in config else {}), + **({ + "kpack": config["kpack"] + } if "kpack" in config else {}), } @@ -294,7 +469,7 @@ def main(args: argparse.Namespace): shard_intermediate_size = 2 * intermediate_size // args.tp_size hidden_size = config.hidden_size - dtype = config.torch_dtype + dtype = torch.float16 if current_platform.is_rocm() else config.torch_dtype use_fp8_w8a8 = args.dtype == "fp8_w8a8" use_int8_w8a16 = args.dtype == "int8_w8a16" @@ -322,7 +497,8 @@ def _distribute(method: str, inputs: List[Any]) -> List[Any]: return ray.get(outputs) if args.tune: - search_space = get_configs_compute_bound() + is_fp16 = not (use_fp8_w8a8 or use_int8_w8a16) + search_space = get_configs_compute_bound(is_fp16) print(f"Start tuning over {len(search_space)} configurations...") start = time.time()