Skip to content

Commit 9c229c9

Browse files
committed
simplify
1 parent a7f4856 commit 9c229c9

File tree

1 file changed

+49
-70
lines changed

1 file changed

+49
-70
lines changed

include/flashinfer/comm/trtllm_moe_allreduce_fusion.cuh

Lines changed: 49 additions & 70 deletions
Original file line number | Diff line number | Diff line change
@@ -1103,6 +1103,47 @@ cudaError_t moereduction_allreduce_fusion_kernel_launcher(
11031103
return cudaSuccess;
11041104
}
11051105

1106+
// Lifts a runtime boolean into a compile-time constant.
//
//   expr        runtime condition; evaluated exactly once, in a boolean context
//   const_expr  identifier of the `constexpr bool` made visible to the callable
//   ...         callable expression (typically a lambda) invoked with no arguments;
//               whatever it yields is returned as cudaError_t
//
// The callable text is expanded once per branch, so each expansion sees its own
// `const_expr` value and may use it where a constant expression is required
// (e.g. as a template argument). The whole macro is an immediately-invoked
// lambda expression, so it can be used on the right-hand side of `return`.
#define DISPATCH_BOOL_(expr, const_expr, ...) \
  [&]() -> cudaError_t {                      \
    if (!(expr)) {                            \
      constexpr bool const_expr = false;      \
      return __VA_ARGS__();                   \
    }                                         \
    constexpr bool const_expr = true;         \
    return __VA_ARGS__();                     \
  }()
1116+
1117+
// Emits a single `case` label for use inside a `switch` over the rank count.
//
//   RANK_LITERAL  integer literal used both as the case label and as the constant's value
//   RANK_CONST    identifier of the `constexpr int` exposed to the callable
//   *_flag        four runtime conditions (allreduce / residual / rms / quant), each
//                 converted to a compile-time bool through DISPATCH_BOOL_
//   *_CONST       identifiers of the `constexpr bool`s exposed to the callable
//   ...           callable expression invoked with no arguments once all five
//                 constants are in scope; its result is returned as cudaError_t
//
// The DISPATCH_BOOL_ invocations have to stay nested (rather than sequential) so
// that every constant declared by an outer level is still in scope at the
// innermost level where the callable is expanded.
// Every path through the emitted case ends in `return`, so no `break` is needed.
// Requires DISPATCH_BOOL_ to be defined at the point of expansion.
#define _DISPATCH_MOEREDUCTION_CASE(RANK_LITERAL, RANK_CONST, ar_flag, res_flag, rms_flag,        \
                                    quant_flag, AR_CONST, RES_CONST, RMS_CONST, QUANT_CONST, ...) \
  case RANK_LITERAL: {                                                                            \
    constexpr int RANK_CONST = RANK_LITERAL;                                                      \
    return DISPATCH_BOOL_(ar_flag, AR_CONST, [&]() -> cudaError_t {                               \
      return DISPATCH_BOOL_(res_flag, RES_CONST, [&]() -> cudaError_t {                           \
        return DISPATCH_BOOL_(rms_flag, RMS_CONST, [&]() -> cudaError_t {                         \
          return DISPATCH_BOOL_(quant_flag, QUANT_CONST,                                          \
                                [&]() -> cudaError_t { return __VA_ARGS__(); });                  \
        });                                                                                       \
      });                                                                                         \
    });                                                                                           \
  }
1129+
1130+
// Top-level dispatcher: converts a runtime rank count plus four runtime flags
// into compile-time constants and invokes the supplied callable with them in scope.
//
//   rank_count    runtime rank count; the values 2, 4, 8 and 16 are handled
//   *_flag        runtime conditions for the allreduce / residual / rms / quant outputs
//   RANK_CONST    identifier of the `constexpr int` exposed to the callable
//   *_CONST       identifiers of the `constexpr bool`s exposed to the callable
//   ...           callable expression (typically a lambda returning cudaError_t)
//
// The macro is an immediately-invoked lambda expression yielding cudaError_t.
// Any other rank count goes through FLASHINFER_CHECK(false, ...) and, should that
// check return, yields cudaErrorNotSupported.
// Requires _DISPATCH_MOEREDUCTION_CASE and FLASHINFER_CHECK at the point of expansion.
#define DISPATCH_MOEREDUCTION(rank_count, ar_flag, res_flag, rms_flag, quant_flag, RANK_CONST, \
                              AR_CONST, RES_CONST, RMS_CONST, QUANT_CONST, ...)                \
  [&]() -> cudaError_t {                                                                       \
    switch (rank_count) {                                                                      \
      _DISPATCH_MOEREDUCTION_CASE(2, RANK_CONST, ar_flag, res_flag, rms_flag, quant_flag,      \
                                  AR_CONST, RES_CONST, RMS_CONST, QUANT_CONST, __VA_ARGS__)    \
      _DISPATCH_MOEREDUCTION_CASE(4, RANK_CONST, ar_flag, res_flag, rms_flag, quant_flag,      \
                                  AR_CONST, RES_CONST, RMS_CONST, QUANT_CONST, __VA_ARGS__)    \
      _DISPATCH_MOEREDUCTION_CASE(8, RANK_CONST, ar_flag, res_flag, rms_flag, quant_flag,      \
                                  AR_CONST, RES_CONST, RMS_CONST, QUANT_CONST, __VA_ARGS__)    \
      _DISPATCH_MOEREDUCTION_CASE(16, RANK_CONST, ar_flag, res_flag, rms_flag, quant_flag,     \
                                  AR_CONST, RES_CONST, RMS_CONST, QUANT_CONST, __VA_ARGS__)    \
      default: {                                                                               \
        FLASHINFER_CHECK(false, "Unsupported n_ranks");                                        \
        return cudaErrorNotSupported;                                                          \
      }                                                                                        \
    }                                                                                          \
  }()
1146+
11061147
template <typename T>
11071148
cudaError_t moereduction_allreduce_fusion_op(MoeReductionAllReduceFusionParams<T> const& params,
11081149
bool launch_with_pdl) {
@@ -1119,72 +1160,6 @@ cudaError_t moereduction_allreduce_fusion_op(MoeReductionAllReduceFusionParams<T
11191160
params.allreduce_out || params.residual_out || params.norm_out || params.quant_out,
11201161
"at least one of allreduce_out, residual_out, norm_out, quant_out must be set");
11211162

1122-
// Selects the kernel-launcher instantiation matching (PARAMS).nranks and returns
// from the ENCLOSING function with the outcome.
//
//   T                element type forwarded as the launcher's first template argument
//   PARAMS           parameter object; its `nranks` member selects the case
//   LAUNCH_WITH_PDL  forwarded unchanged as the launcher's second argument
//   AR/RES/RMS/QUANT compile-time booleans forwarded as template arguments
//
// Control flow: this macro never falls through to the code that follows it.
// A supported rank count (2, 4, 8, 16) ends in `return cudaSuccess` unless
// FLASHINFER_CUDA_CALL leaves the function earlier; an unsupported rank count
// ends in `return cudaErrorNotSupported`.
#define DISPATCH_MOEREDUCTION_KERNEL_RANKS(T, PARAMS, LAUNCH_WITH_PDL, AR, RES, RMS, QUANT) \
  do {                                                                                      \
    switch ((PARAMS).nranks) {                                                              \
      case 2:                                                                               \
        FLASHINFER_CUDA_CALL(                                                               \
            (moereduction_allreduce_fusion_kernel_launcher<T, 2, AR, RES, RMS, QUANT>(      \
                (PARAMS), (LAUNCH_WITH_PDL))));                                             \
        break;                                                                              \
      case 4:                                                                               \
        FLASHINFER_CUDA_CALL(                                                               \
            (moereduction_allreduce_fusion_kernel_launcher<T, 4, AR, RES, RMS, QUANT>(      \
                (PARAMS), (LAUNCH_WITH_PDL))));                                             \
        break;                                                                              \
      case 8:                                                                               \
        FLASHINFER_CUDA_CALL(                                                               \
            (moereduction_allreduce_fusion_kernel_launcher<T, 8, AR, RES, RMS, QUANT>(      \
                (PARAMS), (LAUNCH_WITH_PDL))));                                             \
        break;                                                                              \
      case 16:                                                                              \
        FLASHINFER_CUDA_CALL(                                                               \
            (moereduction_allreduce_fusion_kernel_launcher<T, 16, AR, RES, RMS, QUANT>(     \
                (PARAMS), (LAUNCH_WITH_PDL))));                                             \
        break;                                                                              \
      default:                                                                              \
        FLASHINFER_CHECK(false, "Unsupported nranks");                                      \
        /* Do not fall through to the success return below if the check returns. */         \
        return cudaErrorNotSupported;                                                       \
    }                                                                                       \
    return cudaSuccess;                                                                     \
  } while (0)
1150-
1151-
// Chooses the fusion pattern from which outputs are set on PARAMS and hands over
// to DISPATCH_MOEREDUCTION_KERNEL_RANKS with the matching compile-time flags.
//
//   T                element type forwarded to the rank dispatcher
//   PARAMS           parameter object; `allreduce_out`, `residual_out`, `norm_out`
//                    and `quant_out` are each read once and converted to bool
//   LAUNCH_WITH_PDL  forwarded unchanged to the rank dispatcher
//
// Accepted (residual, norm, quant) combinations, each with or without allreduce:
//   residual + quant | norm only | residual + norm | residual + norm + quant
// Anything else goes through FLASHINFER_CHECK(false, ...) and then returns
// cudaErrorNotSupported from the enclosing function.
#define DISPATCH_MOEREDUCTION_PATTERN(T, PARAMS, LAUNCH_WITH_PDL)                                  \
  do {                                                                                             \
    bool wants_ar = (PARAMS).allreduce_out;                                                        \
    bool wants_res = (PARAMS).residual_out;                                                        \
    bool wants_norm = (PARAMS).norm_out;                                                           \
    bool wants_quant = (PARAMS).quant_out;                                                         \
                                                                                                   \
    if (wants_res && !wants_norm && wants_quant) {                                                 \
      /* [AR +] Residual + Quant */                                                                \
      if (wants_ar) {                                                                              \
        DISPATCH_MOEREDUCTION_KERNEL_RANKS(T, PARAMS, LAUNCH_WITH_PDL, true, true, false, true);   \
      } else {                                                                                     \
        DISPATCH_MOEREDUCTION_KERNEL_RANKS(T, PARAMS, LAUNCH_WITH_PDL, false, true, false, true);  \
      }                                                                                            \
    } else if (!wants_res && wants_norm && !wants_quant) {                                         \
      /* [AR +] RMS */                                                                             \
      if (wants_ar) {                                                                              \
        DISPATCH_MOEREDUCTION_KERNEL_RANKS(T, PARAMS, LAUNCH_WITH_PDL, true, false, true, false);  \
      } else {                                                                                     \
        DISPATCH_MOEREDUCTION_KERNEL_RANKS(T, PARAMS, LAUNCH_WITH_PDL, false, false, true, false); \
      }                                                                                            \
    } else if (wants_res && wants_norm && !wants_quant) {                                          \
      /* [AR +] Add + RMS */                                                                       \
      if (wants_ar) {                                                                              \
        DISPATCH_MOEREDUCTION_KERNEL_RANKS(T, PARAMS, LAUNCH_WITH_PDL, true, true, true, false);   \
      } else {                                                                                     \
        DISPATCH_MOEREDUCTION_KERNEL_RANKS(T, PARAMS, LAUNCH_WITH_PDL, false, true, true, false);  \
      }                                                                                            \
    } else if (wants_res && wants_norm && wants_quant) {                                           \
      /* [AR +] Add + RMS + Quant */                                                               \
      if (wants_ar) {                                                                              \
        DISPATCH_MOEREDUCTION_KERNEL_RANKS(T, PARAMS, LAUNCH_WITH_PDL, true, true, true, true);    \
      } else {                                                                                     \
        DISPATCH_MOEREDUCTION_KERNEL_RANKS(T, PARAMS, LAUNCH_WITH_PDL, false, true, true, true);   \
      }                                                                                            \
    } else {                                                                                       \
      FLASHINFER_CHECK(false, "allreduce_fusion_kernel: unsupported pattern!");                    \
      return cudaErrorNotSupported;                                                                \
    }                                                                                              \
  } while (0)
1187-
11881163
// hidden_dim (d) = 7168 for dpsk moe, and hence 128 tokens as one-shot threshold
11891164
// AR outputs are optional, since we always have fused options followed.
11901165
// pattern1: AR+Residual+Add_RMS+Quant
@@ -1203,10 +1178,14 @@ cudaError_t moereduction_allreduce_fusion_op(MoeReductionAllReduceFusionParams<T
12031178
// [m, d] bf16 allreduce_in, [m, d] bf16 residual_in
12041179
// [m, d] bf16 residual_out, [m, d] bf16 norm_out, [m, d] fp4 quant_out
12051180

1206-
DISPATCH_MOEREDUCTION_PATTERN(T, params, launch_with_pdl);
1207-
1208-
FLASHINFER_CHECK(false, "allreduce_fusion_kernel: unsupported pattern!");
1209-
return cudaErrorNotSupported;
1181+
// Turn nranks and the four fusion switches into compile-time constants
// (N_RANKS, AR, RES, RMS, QUANT) and launch the matching kernel instantiation.
// NOTE(review): the RMS switch is keyed on params.rms_gamma and every flag
// combination is dispatched; the previous dispatch keyed on params.norm_out and
// rejected combinations outside a fixed list -- confirm both changes are intended.
auto status = DISPATCH_MOEREDUCTION(
    params.nranks, params.allreduce_out, params.residual_out, params.rms_gamma, params.quant_out,
    N_RANKS, AR, RES, RMS, QUANT, [&]() -> cudaError_t {
      FLASHINFER_CUDA_CALL(
          (moereduction_allreduce_fusion_kernel_launcher<T, N_RANKS, AR, RES, RMS, QUANT>(
              (params), (launch_with_pdl))));
      // Control reaches this point when the launcher succeeded. The lambda is
      // declared to return cudaError_t, so it must not flow off its end without
      // a value (that would be undefined behaviour).
      return cudaSuccess;
    });
return status;
12101189
}
12111190
} // namespace trtllm_moe_allreduce_fusion
12121191
} // namespace flashinfer

0 commit comments

Comments
 (0)